From bdc882f6cef98e36addca05666edad5e9a7f2d97 Mon Sep 17 00:00:00 2001 From: Lyon Hill Date: Mon, 11 May 2020 15:04:11 -0600 Subject: [PATCH] feat(session): Build out a new session service (#17950) This new session service has the ability to work independant of other systems it relies on having its own store type which should allow us to be more flexible then using the built in kv system. I have included an in mem session store. --- inmem/session_store.go | 86 ++++++++++++ inmem/session_store_test.go | 162 ++++++++++++++++++++++ kv/session.go | 3 +- session.go | 2 + session/errors.go | 14 ++ session/http_server.go | 194 +++++++++++++++++++++++++++ session/http_server_test.go | 98 ++++++++++++++ session/middleware_logging.go | 81 +++++++++++ session/middleware_metrics.go | 56 ++++++++ session/service.go | 175 ++++++++++++++++++++++++ session/service_test.go | 49 +++++++ session/storage.go | 136 +++++++++++++++++++ session/storage_test.go | 120 +++++++++++++++++ task/backend/analytical_storage.go | 2 +- tenant/middleware_onboarding_auth.go | 4 +- tenant/storage.go | 5 + testing/session.go | 7 +- 17 files changed, 1186 insertions(+), 8 deletions(-) create mode 100644 inmem/session_store.go create mode 100644 inmem/session_store_test.go create mode 100644 session/errors.go create mode 100644 session/http_server.go create mode 100644 session/http_server_test.go create mode 100644 session/middleware_logging.go create mode 100644 session/middleware_metrics.go create mode 100644 session/service.go create mode 100644 session/service_test.go create mode 100644 session/storage.go create mode 100644 session/storage_test.go diff --git a/inmem/session_store.go b/inmem/session_store.go new file mode 100644 index 0000000000..ed7167d7af --- /dev/null +++ b/inmem/session_store.go @@ -0,0 +1,86 @@ +package inmem + +import ( + "errors" + "sync" + "time" +) + +type SessionStore struct { + data map[string]string + + timers map[string]*time.Timer + + mu sync.RWMutex +} + +func NewSessionStore() *SessionStore { + return &SessionStore{ + data: map[string]string{}, + timers: map[string]*time.Timer{}, + } +} + +func (s *SessionStore) Set(key, val string, expireAt time.Time) error { + if !expireAt.IsZero() && expireAt.Before(time.Now()) { + // key is already expired. no problem + return nil + } + + s.mu.Lock() + s.data[key] = val + s.mu.Unlock() + + if !expireAt.IsZero() { + return s.ExpireAt(key, expireAt) + } + return nil +} + +func (s *SessionStore) Get(key string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.data[key], nil +} + +func (s *SessionStore) Delete(key string) error { + s.mu.Lock() + defer s.mu.Unlock() + timer := s.timers[key] + if timer != nil { + timer.Stop() + } + + delete(s.data, key) + delete(s.timers, key) + return nil +} + +func (s *SessionStore) ExpireAt(key string, expireAt time.Time) error { + s.mu.Lock() + + existingTimer, ok := s.timers[key] + if ok { + if !existingTimer.Stop() { + return errors.New("session has expired") + } + + } + + duration := time.Until(expireAt) + if duration <= 0 { + s.mu.Unlock() + s.Delete(key) + return nil + } + s.timers[key] = time.AfterFunc(time.Until(expireAt), s.timerExpireFunc(key)) + s.mu.Unlock() + return nil +} + +func (s *SessionStore) timerExpireFunc(key string) func() { + return func() { + s.Delete(key) + } +} diff --git a/inmem/session_store_test.go b/inmem/session_store_test.go new file mode 100644 index 0000000000..d4fb46e605 --- /dev/null +++ b/inmem/session_store_test.go @@ -0,0 +1,162 @@ +package inmem_test + +import ( + "testing" + "time" + + "github.com/influxdata/influxdb/v2/inmem" +) + +func TestSessionSet(t *testing.T) { + st := inmem.NewSessionStore() + err := st.Set("hi", "friend", time.Time{}) + if err != nil { + t.Fatal(err) + } + + err = st.Set("hi", "enemy", time.Time{}) + if err != nil { + t.Fatal(err) + } + + word, err := st.Get("hi") + if err != nil { + t.Fatal(err) + } + + if word != "enemy" { + t.Fatalf("got incorrect response: got %s expected: \"enemy\"", word) + } +} +func TestSessionGet(t *testing.T) { + st := inmem.NewSessionStore() + err := st.Set("hi", "friend", time.Now().Add(time.Second)) + if err != nil { + t.Fatal(err) + } + + word, err := st.Get("hi") + if err != nil { + t.Fatal(err) + } + + if word != "friend" { + t.Fatalf("got incorrect response: got %s expected: \"enemy\"", word) + } + + time.Sleep(time.Second * 2) + + word, err = st.Get("hi") + if err != nil { + t.Fatal(err) + } + + if word != "" { + t.Fatalf("expected no words back but got: %s", word) + } +} +func TestSessionDelete(t *testing.T) { + st := inmem.NewSessionStore() + err := st.Set("hi", "friend", time.Time{}) + if err != nil { + t.Fatal(err) + } + + if err := st.Delete("hi"); err != nil { + t.Fatal(err) + } + + word, err := st.Get("hi") + if err != nil { + t.Fatal(err) + } + + if word != "" { + t.Fatalf("expected no words back but got: %s", word) + } + + if err := st.Delete("hi"); err != nil { + t.Fatal(err) + } +} +func TestSessionExpireAt(t *testing.T) { + st := inmem.NewSessionStore() + err := st.Set("hi", "friend", time.Time{}) + if err != nil { + t.Fatal(err) + } + + if err := st.ExpireAt("hi", time.Now().Add(-20)); err != nil { + t.Fatal(err) + } + + word, err := st.Get("hi") + if err != nil { + t.Fatal(err) + } + + if word != "" { + t.Fatalf("expected no words back but got: %s", word) + } + + if err := st.Set("hello", "friend", time.Time{}); err != nil { + t.Fatal(err) + } + + if err := st.ExpireAt("hello", time.Now()); err != nil { + t.Fatal(err) + } + + word, err = st.Get("hello") + if err != nil { + t.Fatal(err) + } + + if word != "" { + t.Fatalf("expected no words back but got: %s", word) + } + + if err := st.Set("yo", "friend", time.Time{}); err != nil { + t.Fatal(err) + } + + if err := st.ExpireAt("yo", time.Now().Add(100*time.Microsecond)); err != nil { + t.Fatal(err) + } + + word, err = st.Get("yo") + if err != nil { + t.Fatal(err) + } + + if word != "friend" { + t.Fatalf("expected no words back but got: %q", word) + } + + // add more time to a key + if err := st.ExpireAt("yo", time.Now().Add(time.Second)); err != nil { + t.Fatal(err) + } + + time.Sleep(200 * time.Millisecond) + + word, err = st.Get("yo") + if err != nil { + t.Fatal(err) + } + + if word != "friend" { + t.Fatalf("expected key to still exist but got: %q", word) + } + + time.Sleep(time.Second) + + word, err = st.Get("yo") + if err != nil { + t.Fatal(err) + } + + if word != "" { + t.Fatalf("expected no words back but got: %s", word) + } +} diff --git a/kv/session.go b/kv/session.go index 3f4d3e49c7..c8adbf5f3e 100644 --- a/kv/session.go +++ b/kv/session.go @@ -79,8 +79,7 @@ func (s *Service) FindSession(ctx context.Context, key string) (*influxdb.Sessio } if err := sess.Expired(); err != nil { - // todo(leodido) > do we want to return session also if expired? - return sess, &influxdb.Error{ + return nil, &influxdb.Error{ Err: err, } } diff --git a/session.go b/session.go index 8cdd6ad2d5..78ba63988f 100644 --- a/session.go +++ b/session.go @@ -92,5 +92,7 @@ type SessionService interface { FindSession(ctx context.Context, key string) (*Session, error) ExpireSession(ctx context.Context, key string) error CreateSession(ctx context.Context, user string) (*Session, error) + // TODO: update RenewSession to take a ID instead of a session. + // By taking a session object it could be confused to update more things about the session RenewSession(ctx context.Context, session *Session, newExpiration time.Time) error } diff --git a/session/errors.go b/session/errors.go new file mode 100644 index 0000000000..3b268e5da3 --- /dev/null +++ b/session/errors.go @@ -0,0 +1,14 @@ +package session + +import ( + "github.com/influxdata/influxdb/v2" +) + +var ( + // ErrUnauthorized when a session request is unauthorized + // usually due to password missmatch + ErrUnauthorized = &influxdb.Error{ + Code: influxdb.EUnauthorized, + Msg: "unauthorized access", + } +) diff --git a/session/http_server.go b/session/http_server.go new file mode 100644 index 0000000000..387c5f4e59 --- /dev/null +++ b/session/http_server.go @@ -0,0 +1,194 @@ +package session + +import ( + "context" + "net/http" + + "github.com/go-chi/chi" + "github.com/go-chi/chi/middleware" + "github.com/influxdata/influxdb/v2" + kithttp "github.com/influxdata/influxdb/v2/kit/transport/http" + "go.uber.org/zap" +) + +const ( + prefixSignIn = "/api/v2/signin" + prefixSignOut = "/api/v2/signout" +) + +// SessionHandler represents an HTTP API handler for authorizations. +type SessionHandler struct { + chi.Router + api *kithttp.API + log *zap.Logger + + sessionSvc influxdb.SessionService + passSvc influxdb.PasswordsService + userSvc influxdb.UserService +} + +// NewSessionHandler returns a new instance of SessionHandler. +func NewSessionHandler(log *zap.Logger, sessionSvc influxdb.SessionService, userSvc influxdb.UserService, passwordsSvc influxdb.PasswordsService) *SessionHandler { + svr := &SessionHandler{ + api: kithttp.NewAPI(kithttp.WithLog(log)), + log: log, + + passSvc: passwordsSvc, + sessionSvc: sessionSvc, + userSvc: userSvc, + } + + return svr +} + +type resourceHandler struct { + prefix string + *SessionHandler +} + +// Prefix is necessary to mount the router as a resource handler +func (r resourceHandler) Prefix() string { return r.prefix } + +// SignInResourceHandler allows us to return 2 different rousource handler +// for the appropriate mounting location +func (h SessionHandler) SignInResourceHandler() *resourceHandler { + h.Router = chi.NewRouter() + h.Router.Use( + middleware.Recoverer, + middleware.RequestID, + middleware.RealIP, + ) + h.Router.Post("/", h.handleSignin) + return &resourceHandler{prefix: prefixSignIn, SessionHandler: &h} +} + +// SignOutResourceHandler allows us to return 2 different rousource handler +// for the appropriate mounting location +func (h SessionHandler) SignOutResourceHandler() *resourceHandler { + h.Router = chi.NewRouter() + h.Router.Use( + middleware.Recoverer, + middleware.RequestID, + middleware.RealIP, + ) + h.Router.Post("/", h.handleSignout) + return &resourceHandler{prefix: prefixSignOut, SessionHandler: &h} +} + +// handleSignin is the HTTP handler for the POST /signin route. +func (h *SessionHandler) handleSignin(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + req, decErr := decodeSigninRequest(ctx, r) + if decErr != nil { + h.api.Err(w, ErrUnauthorized) + return + } + + u, err := h.userSvc.FindUser(ctx, influxdb.UserFilter{ + Name: &req.Username, + }) + if err != nil { + h.api.Err(w, ErrUnauthorized) + return + } + + if err := h.passSvc.ComparePassword(ctx, u.ID, req.Password); err != nil { + h.api.Err(w, ErrUnauthorized) + return + } + + s, e := h.sessionSvc.CreateSession(ctx, req.Username) + if e != nil { + h.api.Err(w, ErrUnauthorized) + return + } + + encodeCookieSession(w, s) + w.WriteHeader(http.StatusNoContent) +} + +type signinRequest struct { + Username string + Password string +} + +func decodeSigninRequest(ctx context.Context, r *http.Request) (*signinRequest, *influxdb.Error) { + u, p, ok := r.BasicAuth() + if !ok { + return nil, &influxdb.Error{ + Code: influxdb.EInvalid, + Msg: "invalid basic auth", + } + } + + return &signinRequest{ + Username: u, + Password: p, + }, nil +} + +// handleSignout is the HTTP handler for the POST /signout route. +func (h *SessionHandler) handleSignout(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + req, err := decodeSignoutRequest(ctx, r) + if err != nil { + h.api.Err(w, ErrUnauthorized) + return + } + + if err := h.sessionSvc.ExpireSession(ctx, req.Key); err != nil { + h.api.Err(w, ErrUnauthorized) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +type signoutRequest struct { + Key string +} + +func decodeSignoutRequest(ctx context.Context, r *http.Request) (*signoutRequest, error) { + key, err := decodeCookieSession(ctx, r) + if err != nil { + return nil, err + } + return &signoutRequest{ + Key: key, + }, nil +} + +const cookieSessionName = "session" + +func encodeCookieSession(w http.ResponseWriter, s *influxdb.Session) { + c := &http.Cookie{ + Name: cookieSessionName, + Value: s.Key, + } + + http.SetCookie(w, c) +} + +func decodeCookieSession(ctx context.Context, r *http.Request) (string, error) { + c, err := r.Cookie(cookieSessionName) + if err != nil { + return "", &influxdb.Error{ + Code: influxdb.EInvalid, + Err: err, + } + } + return c.Value, nil +} + +// SetCookieSession adds a cookie for the session to an http request +func SetCookieSession(key string, r *http.Request) { + c := &http.Cookie{ + Name: cookieSessionName, + Value: key, + Secure: true, + } + + r.AddCookie(c) +} diff --git a/session/http_server_test.go b/session/http_server_test.go new file mode 100644 index 0000000000..dbc9a4d549 --- /dev/null +++ b/session/http_server_test.go @@ -0,0 +1,98 @@ +package session + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/mock" + "go.uber.org/zap/zaptest" +) + +func TestSessionHandler_handleSignin(t *testing.T) { + type fields struct { + PasswordsService influxdb.PasswordsService + SessionService influxdb.SessionService + } + type args struct { + user string + password string + } + type wants struct { + cookie string + code int + } + + tests := []struct { + name string + fields fields + args args + wants wants + }{ + { + name: "successful compare password", + fields: fields{ + SessionService: &mock.SessionService{ + CreateSessionFn: func(context.Context, string) (*influxdb.Session, error) { + return &influxdb.Session{ + ID: influxdb.ID(0), + Key: "abc123xyz", + CreatedAt: time.Date(2018, 9, 26, 0, 0, 0, 0, time.UTC), + ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC), + UserID: influxdb.ID(1), + }, nil + }, + }, + PasswordsService: &mock.PasswordsService{ + ComparePasswordFn: func(context.Context, influxdb.ID, string) error { + return nil + }, + }, + }, + args: args{ + user: "user1", + password: "supersecret", + }, + wants: wants{ + cookie: "session=abc123xyz", + code: http.StatusNoContent, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + userSVC := mock.NewUserService() + userSVC.FindUserFn = func(_ context.Context, f influxdb.UserFilter) (*influxdb.User, error) { + return &influxdb.User{ID: 1}, nil + } + h := NewSessionHandler(zaptest.NewLogger(t), tt.fields.SessionService, userSVC, tt.fields.PasswordsService) + + server := httptest.NewServer(h.SignInResourceHandler()) + client := server.Client() + + r, err := http.NewRequest("POST", server.URL, nil) + if err != nil { + t.Fatal(err) + } + r.SetBasicAuth(tt.args.user, tt.args.password) + + resp, err := client.Do(r) + if err != nil { + t.Fatal(err) + } + + if got, want := resp.StatusCode, tt.wants.code; got != want { + t.Errorf("bad status code: got %d want %d", got, want) + } + + cookie := resp.Header.Get("Set-Cookie") + if got, want := cookie, tt.wants.cookie; got != want { + t.Errorf("expected session cookie to be set: got %q want %q", got, want) + } + }) + } +} diff --git a/session/middleware_logging.go b/session/middleware_logging.go new file mode 100644 index 0000000000..1c9a3b4849 --- /dev/null +++ b/session/middleware_logging.go @@ -0,0 +1,81 @@ +package session + +import ( + "context" + "time" + + "github.com/influxdata/influxdb/v2" + "go.uber.org/zap" +) + +// SessionLogger is a logger service middleware for sessions +type SessionLogger struct { + logger *zap.Logger + sessionService influxdb.SessionService +} + +var _ influxdb.SessionService = (*SessionLogger)(nil) + +// NewSessionLogger returns a logging service middleware for the User Service. +func NewSessionLogger(log *zap.Logger, s influxdb.SessionService) *SessionLogger { + return &SessionLogger{ + logger: log, + sessionService: s, + } +} + +// FindSession calls the underlying session service and logs the results of the request +func (l *SessionLogger) FindSession(ctx context.Context, key string) (session *influxdb.Session, err error) { + defer func(start time.Time) { + dur := zap.Duration("took", time.Since(start)) + if err != nil { + l.logger.Error("failed to session find", zap.Error(err), dur) + return + } + l.logger.Debug("session find", dur) + }(time.Now()) + return l.sessionService.FindSession(ctx, key) + +} + +// ExpireSession calls the underlying session service and logs the results of the request +func (l *SessionLogger) ExpireSession(ctx context.Context, key string) (err error) { + defer func(start time.Time) { + dur := zap.Duration("took", time.Since(start)) + if err != nil { + l.logger.Error("failed to session expire", zap.Error(err), dur) + return + } + l.logger.Debug("session expire", dur) + }(time.Now()) + return l.sessionService.ExpireSession(ctx, key) + +} + +// CreateSession calls the underlying session service and logs the results of the request +func (l *SessionLogger) CreateSession(ctx context.Context, user string) (s *influxdb.Session, err error) { + defer func(start time.Time) { + dur := zap.Duration("took", time.Since(start)) + if err != nil { + l.logger.Error("failed to session create", zap.Error(err), dur) + return + } + l.logger.Debug("session create", dur) + }(time.Now()) + return l.sessionService.CreateSession(ctx, user) + +} + +// RenewSession calls the underlying session service and logs the results of the request +func (l *SessionLogger) RenewSession(ctx context.Context, session *influxdb.Session, newExpiration time.Time) (err error) { + defer func(start time.Time) { + dur := zap.Duration("took", time.Since(start)) + if err != nil { + l.logger.Error("failed to session renew", zap.Error(err), dur) + return + } + l.logger.Debug("session renew", dur) + }(time.Now()) + return l.sessionService.RenewSession(ctx, session, newExpiration) + +} diff --git a/session/middleware_metrics.go b/session/middleware_metrics.go new file mode 100644 index 0000000000..6e16d02799 --- /dev/null +++ b/session/middleware_metrics.go @@ -0,0 +1,56 @@ +package session + +import ( + "context" + "time" + + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/kit/metric" + "github.com/prometheus/client_golang/prometheus" +) + +// SessionMetrics is a metrics middleware system for the session service +type SessionMetrics struct { + // RED metrics + rec *metric.REDClient + + sessionSvc influxdb.SessionService +} + +var _ influxdb.SessionService = (*SessionMetrics)(nil) + +// NewSessionMetrics creates a new session metrics middleware +func NewSessionMetrics(reg prometheus.Registerer, s influxdb.SessionService) *SessionMetrics { + return &SessionMetrics{ + rec: metric.New(reg, "session"), + sessionSvc: s, + } +} + +// FindSession calls the underlying session service and tracks RED metrics for the call +func (m *SessionMetrics) FindSession(ctx context.Context, key string) (session *influxdb.Session, err error) { + rec := m.rec.Record("find_session") + session, err = m.sessionSvc.FindSession(ctx, key) + return session, rec(err) +} + +// ExpireSession calls the underlying session service and tracks RED metrics for the call +func (m *SessionMetrics) ExpireSession(ctx context.Context, key string) (err error) { + rec := m.rec.Record("expire_session") + err = m.sessionSvc.ExpireSession(ctx, key) + return rec(err) +} + +// CreateSession calls the underlying session service and tracks RED metrics for the call +func (m *SessionMetrics) CreateSession(ctx context.Context, user string) (s *influxdb.Session, err error) { + rec := m.rec.Record("create_session") + s, err = m.sessionSvc.CreateSession(ctx, user) + return s, rec(err) +} + +// RenewSession calls the underlying session service and tracks RED metrics for the call +func (m *SessionMetrics) RenewSession(ctx context.Context, session *influxdb.Session, newExpiration time.Time) (err error) { + rec := m.rec.Record("renew_session") + err = m.sessionSvc.RenewSession(ctx, session, newExpiration) + return rec(err) +} diff --git a/session/service.go b/session/service.go new file mode 100644 index 0000000000..2df6bcd156 --- /dev/null +++ b/session/service.go @@ -0,0 +1,175 @@ +package session + +import ( + "context" + "time" + + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/rand" + "github.com/influxdata/influxdb/v2/snowflake" +) + +// Service implements the influxdb.SessionService interface and +// handles communication between session and the necessary user and urm services +type Service struct { + store *Storage + userService influxdb.UserService + urmService influxdb.UserResourceMappingService + authService influxdb.AuthorizationService + sessionLength time.Duration + + idGen influxdb.IDGenerator + tokenGen influxdb.TokenGenerator + + disableAuthorizationsForMaxPermissions func(context.Context) bool +} + +// NewService creates a new session service +func NewService(store *Storage, userService influxdb.UserService, urmService influxdb.UserResourceMappingService, authSvc influxdb.AuthorizationService, sessionLength time.Duration) *Service { + if sessionLength <= 0 { + sessionLength = time.Hour + } + return &Service{ + store: store, + userService: userService, + urmService: urmService, + authService: authSvc, + sessionLength: sessionLength, + idGen: snowflake.NewIDGenerator(), + tokenGen: rand.NewTokenGenerator(64), + disableAuthorizationsForMaxPermissions: func(context.Context) bool { + return false + }, + } +} + +// WithMaxPermissionFunc sets the useAuthorizationsForMaxPermissions function +// which can trigger whether or not max permissions uses the users authorizations +// to derive maximum permissions. +func (s *Service) WithMaxPermissionFunc(fn func(context.Context) bool) { + s.disableAuthorizationsForMaxPermissions = fn +} + +// FindSession finds a session based on the session key +func (s *Service) FindSession(ctx context.Context, key string) (*influxdb.Session, error) { + session, err := s.store.FindSessionByKey(ctx, key) + if err != nil { + return nil, err + } + + // TODO: We want to be able to store permissions in the session + // but the contract provided by urm's doesn't give us enough information to quickly repopulate our + // session permissions on updates so we are required to pull the permissions every time we find the session. + permissions, err := s.getPermissionSet(ctx, session.UserID) + if err != nil { + return nil, err + } + + session.Permissions = permissions + return session, nil +} + +// ExpireSession removes a session from the system +func (s *Service) ExpireSession(ctx context.Context, key string) error { + session, err := s.store.FindSessionByKey(ctx, key) + if err != nil { + return err + } + return s.store.DeleteSession(ctx, session.ID) +} + +// CreateSession +func (s *Service) CreateSession(ctx context.Context, user string) (*influxdb.Session, error) { + u, err := s.userService.FindUser(ctx, influxdb.UserFilter{ + Name: &user, + }) + if err != nil { + return nil, err + } + + token, err := s.tokenGen.Token() + if err != nil { + return nil, err + } + + // for now we are not storing the permissions because we need to pull them every time we find + // so we might as well keep the session stored small + now := time.Now() + session := &influxdb.Session{ + ID: s.idGen.ID(), + Key: token, + CreatedAt: now, + ExpiresAt: now.Add(s.sessionLength), + UserID: u.ID, + } + + return session, s.store.CreateSession(ctx, session) +} + +// RenewSession update the sessions expiration time +func (s *Service) RenewSession(ctx context.Context, session *influxdb.Session, newExpiration time.Time) error { + if session == nil { + return &influxdb.Error{ + Msg: "session is nil", + } + } + return s.store.RefreshSession(ctx, session.ID, newExpiration) +} + +func (s *Service) getPermissionSet(ctx context.Context, uid influxdb.ID) ([]influxdb.Permission, error) { + + mappings, _, err := s.urmService.FindUserResourceMappings(ctx, influxdb.UserResourceMappingFilter{UserID: uid}, influxdb.FindOptions{Limit: 100}) + if err != nil { + return nil, err + } + + permissions, err := permissionFromMapping(mappings) + if err != nil { + return nil, err + } + + if len(mappings) == 100 { + // if we got 100 mappings we probably need to pull more pages + // account for paginated results + for i := len(mappings); len(mappings) > 0; i += len(mappings) { + mappings, _, err = s.urmService.FindUserResourceMappings(ctx, influxdb.UserResourceMappingFilter{UserID: uid}, influxdb.FindOptions{Offset: i, Limit: 100}) + if err != nil { + return nil, err + } + pms, err := permissionFromMapping(mappings) + if err != nil { + return nil, err + } + permissions = append(permissions, pms...) + } + } + + if !s.disableAuthorizationsForMaxPermissions(ctx) { + as, _, err := s.authService.FindAuthorizations(ctx, influxdb.AuthorizationFilter{UserID: &uid}) + if err != nil { + return nil, err + } + for _, a := range as { + permissions = append(permissions, a.Permissions...) + } + } + + permissions = append(permissions, influxdb.MePermissions(uid)...) + return permissions, nil +} + +func permissionFromMapping(mappings []*influxdb.UserResourceMapping) ([]influxdb.Permission, error) { + ps := make([]influxdb.Permission, 0, len(mappings)) + for _, m := range mappings { + p, err := m.ToPermissions() + if err != nil { + return nil, &influxdb.Error{ + Err: err, + } + } + + ps = append(ps, p...) + } + + return ps, nil +} diff --git a/session/service_test.go b/session/service_test.go new file mode 100644 index 0000000000..5c233aa11a --- /dev/null +++ b/session/service_test.go @@ -0,0 +1,49 @@ +package session + +import ( + "context" + "testing" + "time" + + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/inmem" + "github.com/influxdata/influxdb/v2/mock" + "github.com/influxdata/influxdb/v2/tenant" + influxdbtesting "github.com/influxdata/influxdb/v2/testing" +) + +func TestSessionService(t *testing.T) { + influxdbtesting.SessionService(initSessionService, t) +} + +func initSessionService(f influxdbtesting.SessionFields, t *testing.T) (influxdb.SessionService, string, func()) { + ss := NewStorage(inmem.NewSessionStore()) + ts, _ := tenant.NewStore(inmem.NewKVStore()) + ten := tenant.NewService(ts) + svc := NewService(ss, ten, ten, &mock.AuthorizationService{ + FindAuthorizationsFn: func(context.Context, influxdb.AuthorizationFilter, ...influxdb.FindOptions) ([]*influxdb.Authorization, int, error) { + return []*influxdb.Authorization{}, 0, nil + }, + }, time.Minute) + + if f.IDGenerator != nil { + svc.idGen = f.IDGenerator + } + + if f.TokenGenerator != nil { + svc.tokenGen = f.TokenGenerator + } + + ctx := context.Background() + for _, u := range f.Users { + if err := ten.CreateUser(ctx, u); err != nil { + t.Fatalf("failed to populate users") + } + } + for _, s := range f.Sessions { + if err := ss.CreateSession(ctx, s); err != nil { + t.Fatalf("failed to populate sessions") + } + } + return svc, "session", func() {} +} diff --git a/session/storage.go b/session/storage.go new file mode 100644 index 0000000000..e84e4f01d0 --- /dev/null +++ b/session/storage.go @@ -0,0 +1,136 @@ +package session + +import ( + "context" + "encoding/json" + "time" + + "github.com/influxdata/influxdb/v2" +) + +type Store interface { + Set(key, val string, expireAt time.Time) error + Get(key string) (string, error) + Delete(key string) error + ExpireAt(key string, expireAt time.Time) error +} + +var storePrefix = "sessionsv1/" +var storeIndex = "sessionsindexv1/" + +// Storage is a store translation layer between the data storage unit and the +// service layer. +type Storage struct { + store Store +} + +// NewStorage creates a new storage system +func NewStorage(s Store) *Storage { + return &Storage{s} +} + +// FindSessionByKey use a given key to retrieve the stored session +func (s *Storage) FindSessionByKey(ctx context.Context, key string) (*influxdb.Session, error) { + val, err := s.store.Get(sessionIndexKey(key)) + if err != nil { + return nil, err + } + + if val == "" { + return nil, &influxdb.Error{ + Code: influxdb.ENotFound, + Msg: influxdb.ErrSessionNotFound, + } + } + + id, err := influxdb.IDFromString(val) + if err != nil { + return nil, err + } + return s.FindSessionByID(ctx, *id) +} + +// FindSessionByID use a provided id to retrieve the stored session +func (s *Storage) FindSessionByID(ctx context.Context, id influxdb.ID) (*influxdb.Session, error) { + val, err := s.store.Get(storePrefix + id.String()) + if err != nil { + return nil, err + } + if val == "" { + return nil, &influxdb.Error{ + Code: influxdb.ENotFound, + Msg: influxdb.ErrSessionNotFound, + } + } + + session := &influxdb.Session{} + return session, json.Unmarshal([]byte(val), session) +} + +// CreateSession creates a new session +func (s *Storage) CreateSession(ctx context.Context, session *influxdb.Session) error { + // create session + sessionBytes, err := json.Marshal(session) + if err != nil { + return err + } + + // use a minute time just so the session will expire if we fail to set the expiration later + sessionID := sessionID(session.ID) + if err := s.store.Set(sessionID, string(sessionBytes), session.ExpiresAt); err != nil { + return err + } + + // create index + indexKey := sessionIndexKey(session.Key) + if err := s.store.Set(indexKey, session.ID.String(), session.ExpiresAt); err != nil { + return err + } + + return nil +} + +// RefreshSession updates the expiration time of a session. +func (s *Storage) RefreshSession(ctx context.Context, id influxdb.ID, expireAt time.Time) error { + session, err := s.FindSessionByID(ctx, id) + if err != nil { + return err + } + + if expireAt.Before(session.ExpiresAt) { + // no need to recreate the session if we aren't extending the expiration + return nil + } + + session.ExpiresAt = expireAt + return s.CreateSession(ctx, session) +} + +// DeleteSession removes the session and index from storage +func (s *Storage) DeleteSession(ctx context.Context, id influxdb.ID) error { + session, err := s.FindSessionByID(ctx, id) + if err != nil { + return err + } + if session == nil { + return nil + } + + if err := s.store.Delete(sessionID(session.ID)); err != nil { + return err + } + + if err := s.store.Delete(sessionIndexKey(session.Key)); err != nil { + return err + } + + return nil +} + +func sessionID(id influxdb.ID) string { + return storePrefix + id.String() +} + +func sessionIndexKey(key string) string { + return storeIndex + key +} diff --git a/session/storage_test.go b/session/storage_test.go new file mode 100644 index 0000000000..3e208c65e3 --- /dev/null +++ b/session/storage_test.go @@ -0,0 +1,120 @@ +package session_test + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/inmem" + "github.com/influxdata/influxdb/v2/session" +) + +func TestSessionStore(t *testing.T) { + driver := func() session.Store { + return inmem.NewSessionStore() + } + + expected := &influxdb.Session{ + ID: 1, + Key: "2", + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Hour), + } + + simpleSetup := func(t *testing.T, store *session.Storage) { + err := store.CreateSession( + context.Background(), + expected, + ) + if err != nil { + t.Fatal(err) + } + } + + st := []struct { + name string + setup func(*testing.T, *session.Storage) + update func(*testing.T, *session.Storage) + results func(*testing.T, *session.Storage) + }{ + { + name: "create", + setup: simpleSetup, + results: func(t *testing.T, store *session.Storage) { + session, err := store.FindSessionByID(context.Background(), 1) + if err != nil { + t.Fatal(err) + } + + if !cmp.Equal(session, expected) { + t.Fatalf("expected identical sessions: \n%+v\n%+v", session, expected) + } + }, + }, + { + name: "get", + setup: simpleSetup, + results: func(t *testing.T, store *session.Storage) { + session, err := store.FindSessionByID(context.Background(), 1) + if err != nil { + t.Fatal(err) + } + + if !cmp.Equal(session, expected) { + t.Fatalf("expected identical sessions: \n%+v\n%+v", session, expected) + } + + session, err = store.FindSessionByKey(context.Background(), "2") + if err != nil { + t.Fatal(err) + } + + if !cmp.Equal(session, expected) { + t.Fatalf("expected identical sessions: \n%+v\n%+v", session, expected) + } + }, + }, + { + name: "delete", + setup: simpleSetup, + update: func(t *testing.T, store *session.Storage) { + err := store.DeleteSession(context.Background(), 1) + if err != nil { + t.Fatal(err) + } + }, + results: func(t *testing.T, store *session.Storage) { + session, err := store.FindSessionByID(context.Background(), 1) + if err == nil { + t.Fatal("expected error on deleted session but got none") + } + + if session != nil { + t.Fatal("got a session when none should have existed") + } + }, + }, + } + for _, testScenario := range st { + t.Run(testScenario.name, func(t *testing.T) { + ss := session.NewStorage(driver()) + + // setup + if testScenario.setup != nil { + testScenario.setup(t, ss) + } + + // update + if testScenario.update != nil { + testScenario.update(t, ss) + } + + // results + if testScenario.results != nil { + testScenario.results(t, ss) + } + }) + } +} diff --git a/task/backend/analytical_storage.go b/task/backend/analytical_storage.go index 38b00fe1b9..9820f40d67 100644 --- a/task/backend/analytical_storage.go +++ b/task/backend/analytical_storage.go @@ -326,7 +326,7 @@ func (as *AnalyticalStorage) RetryRun(ctx context.Context, taskID, runID influxd return run, err } - // try finding the run (in our system or underlieing) + // try finding the run (in our system or underlying) run, err = as.FindRunByID(ctx, taskID, runID) if err != nil { return run, err diff --git a/tenant/middleware_onboarding_auth.go b/tenant/middleware_onboarding_auth.go index daf5304b7c..2861874b50 100644 --- a/tenant/middleware_onboarding_auth.go +++ b/tenant/middleware_onboarding_auth.go @@ -24,12 +24,12 @@ func NewAuthedOnboardSvc(s influxdb.OnboardingService) *AuthedOnboardSvc { } } -// IsOnboarding pass through. this is handled by the underlieing service layer +// IsOnboarding pass through. this is handled by the underlying service layer func (s *AuthedOnboardSvc) IsOnboarding(ctx context.Context) (bool, error) { return s.s.IsOnboarding(ctx) } -// OnboardInitialUser pass through. this is handled by the underlieing service layer +// OnboardInitialUser pass through. this is handled by the underlying service layer func (s *AuthedOnboardSvc) OnboardInitialUser(ctx context.Context, req *influxdb.OnboardingRequest) (*influxdb.OnboardingResults, error) { return s.s.OnboardInitialUser(ctx, req) } diff --git a/tenant/storage.go b/tenant/storage.go index b2a0e19f82..611d6e3d45 100644 --- a/tenant/storage.go +++ b/tenant/storage.go @@ -68,6 +68,11 @@ func (s *Store) setup() error { if _, err := tx.Bucket(urmBucket); err != nil { return err } + + if _, err := tx.Bucket(urmByUserIndexBucket); err != nil { + return err + } + if _, err := tx.Bucket(organizationBucket); err != nil { return err } diff --git a/testing/session.go b/testing/session.go index e926caecc9..cb794905cd 100644 --- a/testing/session.go +++ b/testing/session.go @@ -9,6 +9,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/influxdata/influxdb/v2" platform "github.com/influxdata/influxdb/v2" "github.com/influxdata/influxdb/v2/mock" ) @@ -274,12 +275,12 @@ func ExpireSession( diffPlatformErrors(tt.name, err, tt.wants.err, opPrefix, t) session, err := s.FindSession(ctx, tt.args.key) - if err == nil { + if err.Error() != influxdb.ErrSessionExpired && err.Error() != influxdb.ErrSessionNotFound { t.Errorf("expected session to be expired got %v", err) } - if diff := cmp.Diff(session, tt.wants.session, sessionCmpOptions...); diff != "" { - t.Errorf("session is different -got/+want\ndiff %s", diff) + if session != nil { + t.Errorf("expected a nil session but got: %v", session) } }) }