diff --git a/inmem/session_store.go b/inmem/session_store.go new file mode 100644 index 0000000000..ed7167d7af --- /dev/null +++ b/inmem/session_store.go @@ -0,0 +1,86 @@ +package inmem + +import ( + "errors" + "sync" + "time" +) + +type SessionStore struct { + data map[string]string + + timers map[string]*time.Timer + + mu sync.RWMutex +} + +func NewSessionStore() *SessionStore { + return &SessionStore{ + data: map[string]string{}, + timers: map[string]*time.Timer{}, + } +} + +func (s *SessionStore) Set(key, val string, expireAt time.Time) error { + if !expireAt.IsZero() && expireAt.Before(time.Now()) { + // key is already expired. no problem + return nil + } + + s.mu.Lock() + s.data[key] = val + s.mu.Unlock() + + if !expireAt.IsZero() { + return s.ExpireAt(key, expireAt) + } + return nil +} + +func (s *SessionStore) Get(key string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.data[key], nil +} + +func (s *SessionStore) Delete(key string) error { + s.mu.Lock() + defer s.mu.Unlock() + timer := s.timers[key] + if timer != nil { + timer.Stop() + } + + delete(s.data, key) + delete(s.timers, key) + return nil +} + +func (s *SessionStore) ExpireAt(key string, expireAt time.Time) error { + s.mu.Lock() + + existingTimer, ok := s.timers[key] + if ok { + if !existingTimer.Stop() { + return errors.New("session has expired") + } + + } + + duration := time.Until(expireAt) + if duration <= 0 { + s.mu.Unlock() + s.Delete(key) + return nil + } + s.timers[key] = time.AfterFunc(time.Until(expireAt), s.timerExpireFunc(key)) + s.mu.Unlock() + return nil +} + +func (s *SessionStore) timerExpireFunc(key string) func() { + return func() { + s.Delete(key) + } +} diff --git a/inmem/session_store_test.go b/inmem/session_store_test.go new file mode 100644 index 0000000000..d4fb46e605 --- /dev/null +++ b/inmem/session_store_test.go @@ -0,0 +1,162 @@ +package inmem_test + +import ( + "testing" + "time" + + "github.com/influxdata/influxdb/v2/inmem" +) + +func TestSessionSet(t *testing.T) { + st := inmem.NewSessionStore() + err := st.Set("hi", "friend", time.Time{}) + if err != nil { + t.Fatal(err) + } + + err = st.Set("hi", "enemy", time.Time{}) + if err != nil { + t.Fatal(err) + } + + word, err := st.Get("hi") + if err != nil { + t.Fatal(err) + } + + if word != "enemy" { + t.Fatalf("got incorrect response: got %s expected: \"enemy\"", word) + } +} +func TestSessionGet(t *testing.T) { + st := inmem.NewSessionStore() + err := st.Set("hi", "friend", time.Now().Add(time.Second)) + if err != nil { + t.Fatal(err) + } + + word, err := st.Get("hi") + if err != nil { + t.Fatal(err) + } + + if word != "friend" { + t.Fatalf("got incorrect response: got %s expected: \"enemy\"", word) + } + + time.Sleep(time.Second * 2) + + word, err = st.Get("hi") + if err != nil { + t.Fatal(err) + } + + if word != "" { + t.Fatalf("expected no words back but got: %s", word) + } +} +func TestSessionDelete(t *testing.T) { + st := inmem.NewSessionStore() + err := st.Set("hi", "friend", time.Time{}) + if err != nil { + t.Fatal(err) + } + + if err := st.Delete("hi"); err != nil { + t.Fatal(err) + } + + word, err := st.Get("hi") + if err != nil { + t.Fatal(err) + } + + if word != "" { + t.Fatalf("expected no words back but got: %s", word) + } + + if err := st.Delete("hi"); err != nil { + t.Fatal(err) + } +} +func TestSessionExpireAt(t *testing.T) { + st := inmem.NewSessionStore() + err := st.Set("hi", "friend", time.Time{}) + if err != nil { + t.Fatal(err) + } + + if err := st.ExpireAt("hi", time.Now().Add(-20)); err != nil { + t.Fatal(err) + } + + word, err := st.Get("hi") + if err != nil { + t.Fatal(err) + } + + if word != "" { + t.Fatalf("expected no words back but got: %s", word) + } + + if err := st.Set("hello", "friend", time.Time{}); err != nil { + t.Fatal(err) + } + + if err := st.ExpireAt("hello", time.Now()); err != nil { + t.Fatal(err) + } + + word, err = st.Get("hello") + if err != nil { + t.Fatal(err) + } + + if word != "" { + t.Fatalf("expected no words back but got: %s", word) + } + + if err := st.Set("yo", "friend", time.Time{}); err != nil { + t.Fatal(err) + } + + if err := st.ExpireAt("yo", time.Now().Add(100*time.Microsecond)); err != nil { + t.Fatal(err) + } + + word, err = st.Get("yo") + if err != nil { + t.Fatal(err) + } + + if word != "friend" { + t.Fatalf("expected no words back but got: %q", word) + } + + // add more time to a key + if err := st.ExpireAt("yo", time.Now().Add(time.Second)); err != nil { + t.Fatal(err) + } + + time.Sleep(200 * time.Millisecond) + + word, err = st.Get("yo") + if err != nil { + t.Fatal(err) + } + + if word != "friend" { + t.Fatalf("expected key to still exist but got: %q", word) + } + + time.Sleep(time.Second) + + word, err = st.Get("yo") + if err != nil { + t.Fatal(err) + } + + if word != "" { + t.Fatalf("expected no words back but got: %s", word) + } +} diff --git a/kv/session.go b/kv/session.go index 3f4d3e49c7..c8adbf5f3e 100644 --- a/kv/session.go +++ b/kv/session.go @@ -79,8 +79,7 @@ func (s *Service) FindSession(ctx context.Context, key string) (*influxdb.Sessio } if err := sess.Expired(); err != nil { - // todo(leodido) > do we want to return session also if expired? - return sess, &influxdb.Error{ + return nil, &influxdb.Error{ Err: err, } } diff --git a/session.go b/session.go index 8cdd6ad2d5..78ba63988f 100644 --- a/session.go +++ b/session.go @@ -92,5 +92,7 @@ type SessionService interface { FindSession(ctx context.Context, key string) (*Session, error) ExpireSession(ctx context.Context, key string) error CreateSession(ctx context.Context, user string) (*Session, error) + // TODO: update RenewSession to take a ID instead of a session. + // By taking a session object it could be confused to update more things about the session RenewSession(ctx context.Context, session *Session, newExpiration time.Time) error } diff --git a/session/errors.go b/session/errors.go new file mode 100644 index 0000000000..3b268e5da3 --- /dev/null +++ b/session/errors.go @@ -0,0 +1,14 @@ +package session + +import ( + "github.com/influxdata/influxdb/v2" +) + +var ( + // ErrUnauthorized when a session request is unauthorized + // usually due to password missmatch + ErrUnauthorized = &influxdb.Error{ + Code: influxdb.EUnauthorized, + Msg: "unauthorized access", + } +) diff --git a/session/http_server.go b/session/http_server.go new file mode 100644 index 0000000000..387c5f4e59 --- /dev/null +++ b/session/http_server.go @@ -0,0 +1,194 @@ +package session + +import ( + "context" + "net/http" + + "github.com/go-chi/chi" + "github.com/go-chi/chi/middleware" + "github.com/influxdata/influxdb/v2" + kithttp "github.com/influxdata/influxdb/v2/kit/transport/http" + "go.uber.org/zap" +) + +const ( + prefixSignIn = "/api/v2/signin" + prefixSignOut = "/api/v2/signout" +) + +// SessionHandler represents an HTTP API handler for authorizations. +type SessionHandler struct { + chi.Router + api *kithttp.API + log *zap.Logger + + sessionSvc influxdb.SessionService + passSvc influxdb.PasswordsService + userSvc influxdb.UserService +} + +// NewSessionHandler returns a new instance of SessionHandler. +func NewSessionHandler(log *zap.Logger, sessionSvc influxdb.SessionService, userSvc influxdb.UserService, passwordsSvc influxdb.PasswordsService) *SessionHandler { + svr := &SessionHandler{ + api: kithttp.NewAPI(kithttp.WithLog(log)), + log: log, + + passSvc: passwordsSvc, + sessionSvc: sessionSvc, + userSvc: userSvc, + } + + return svr +} + +type resourceHandler struct { + prefix string + *SessionHandler +} + +// Prefix is necessary to mount the router as a resource handler +func (r resourceHandler) Prefix() string { return r.prefix } + +// SignInResourceHandler allows us to return 2 different rousource handler +// for the appropriate mounting location +func (h SessionHandler) SignInResourceHandler() *resourceHandler { + h.Router = chi.NewRouter() + h.Router.Use( + middleware.Recoverer, + middleware.RequestID, + middleware.RealIP, + ) + h.Router.Post("/", h.handleSignin) + return &resourceHandler{prefix: prefixSignIn, SessionHandler: &h} +} + +// SignOutResourceHandler allows us to return 2 different rousource handler +// for the appropriate mounting location +func (h SessionHandler) SignOutResourceHandler() *resourceHandler { + h.Router = chi.NewRouter() + h.Router.Use( + middleware.Recoverer, + middleware.RequestID, + middleware.RealIP, + ) + h.Router.Post("/", h.handleSignout) + return &resourceHandler{prefix: prefixSignOut, SessionHandler: &h} +} + +// handleSignin is the HTTP handler for the POST /signin route. +func (h *SessionHandler) handleSignin(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + req, decErr := decodeSigninRequest(ctx, r) + if decErr != nil { + h.api.Err(w, ErrUnauthorized) + return + } + + u, err := h.userSvc.FindUser(ctx, influxdb.UserFilter{ + Name: &req.Username, + }) + if err != nil { + h.api.Err(w, ErrUnauthorized) + return + } + + if err := h.passSvc.ComparePassword(ctx, u.ID, req.Password); err != nil { + h.api.Err(w, ErrUnauthorized) + return + } + + s, e := h.sessionSvc.CreateSession(ctx, req.Username) + if e != nil { + h.api.Err(w, ErrUnauthorized) + return + } + + encodeCookieSession(w, s) + w.WriteHeader(http.StatusNoContent) +} + +type signinRequest struct { + Username string + Password string +} + +func decodeSigninRequest(ctx context.Context, r *http.Request) (*signinRequest, *influxdb.Error) { + u, p, ok := r.BasicAuth() + if !ok { + return nil, &influxdb.Error{ + Code: influxdb.EInvalid, + Msg: "invalid basic auth", + } + } + + return &signinRequest{ + Username: u, + Password: p, + }, nil +} + +// handleSignout is the HTTP handler for the POST /signout route. +func (h *SessionHandler) handleSignout(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + req, err := decodeSignoutRequest(ctx, r) + if err != nil { + h.api.Err(w, ErrUnauthorized) + return + } + + if err := h.sessionSvc.ExpireSession(ctx, req.Key); err != nil { + h.api.Err(w, ErrUnauthorized) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +type signoutRequest struct { + Key string +} + +func decodeSignoutRequest(ctx context.Context, r *http.Request) (*signoutRequest, error) { + key, err := decodeCookieSession(ctx, r) + if err != nil { + return nil, err + } + return &signoutRequest{ + Key: key, + }, nil +} + +const cookieSessionName = "session" + +func encodeCookieSession(w http.ResponseWriter, s *influxdb.Session) { + c := &http.Cookie{ + Name: cookieSessionName, + Value: s.Key, + } + + http.SetCookie(w, c) +} + +func decodeCookieSession(ctx context.Context, r *http.Request) (string, error) { + c, err := r.Cookie(cookieSessionName) + if err != nil { + return "", &influxdb.Error{ + Code: influxdb.EInvalid, + Err: err, + } + } + return c.Value, nil +} + +// SetCookieSession adds a cookie for the session to an http request +func SetCookieSession(key string, r *http.Request) { + c := &http.Cookie{ + Name: cookieSessionName, + Value: key, + Secure: true, + } + + r.AddCookie(c) +} diff --git a/session/http_server_test.go b/session/http_server_test.go new file mode 100644 index 0000000000..dbc9a4d549 --- /dev/null +++ b/session/http_server_test.go @@ -0,0 +1,98 @@ +package session + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/mock" + "go.uber.org/zap/zaptest" +) + +func TestSessionHandler_handleSignin(t *testing.T) { + type fields struct { + PasswordsService influxdb.PasswordsService + SessionService influxdb.SessionService + } + type args struct { + user string + password string + } + type wants struct { + cookie string + code int + } + + tests := []struct { + name string + fields fields + args args + wants wants + }{ + { + name: "successful compare password", + fields: fields{ + SessionService: &mock.SessionService{ + CreateSessionFn: func(context.Context, string) (*influxdb.Session, error) { + return &influxdb.Session{ + ID: influxdb.ID(0), + Key: "abc123xyz", + CreatedAt: time.Date(2018, 9, 26, 0, 0, 0, 0, time.UTC), + ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC), + UserID: influxdb.ID(1), + }, nil + }, + }, + PasswordsService: &mock.PasswordsService{ + ComparePasswordFn: func(context.Context, influxdb.ID, string) error { + return nil + }, + }, + }, + args: args{ + user: "user1", + password: "supersecret", + }, + wants: wants{ + cookie: "session=abc123xyz", + code: http.StatusNoContent, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + userSVC := mock.NewUserService() + userSVC.FindUserFn = func(_ context.Context, f influxdb.UserFilter) (*influxdb.User, error) { + return &influxdb.User{ID: 1}, nil + } + h := NewSessionHandler(zaptest.NewLogger(t), tt.fields.SessionService, userSVC, tt.fields.PasswordsService) + + server := httptest.NewServer(h.SignInResourceHandler()) + client := server.Client() + + r, err := http.NewRequest("POST", server.URL, nil) + if err != nil { + t.Fatal(err) + } + r.SetBasicAuth(tt.args.user, tt.args.password) + + resp, err := client.Do(r) + if err != nil { + t.Fatal(err) + } + + if got, want := resp.StatusCode, tt.wants.code; got != want { + t.Errorf("bad status code: got %d want %d", got, want) + } + + cookie := resp.Header.Get("Set-Cookie") + if got, want := cookie, tt.wants.cookie; got != want { + t.Errorf("expected session cookie to be set: got %q want %q", got, want) + } + }) + } +} diff --git a/session/middleware_logging.go b/session/middleware_logging.go new file mode 100644 index 0000000000..1c9a3b4849 --- /dev/null +++ b/session/middleware_logging.go @@ -0,0 +1,81 @@ +package session + +import ( + "context" + "time" + + "github.com/influxdata/influxdb/v2" + "go.uber.org/zap" +) + +// SessionLogger is a logger service middleware for sessions +type SessionLogger struct { + logger *zap.Logger + sessionService influxdb.SessionService +} + +var _ influxdb.SessionService = (*SessionLogger)(nil) + +// NewSessionLogger returns a logging service middleware for the User Service. +func NewSessionLogger(log *zap.Logger, s influxdb.SessionService) *SessionLogger { + return &SessionLogger{ + logger: log, + sessionService: s, + } +} + +// FindSession calls the underlying session service and logs the results of the request +func (l *SessionLogger) FindSession(ctx context.Context, key string) (session *influxdb.Session, err error) { + defer func(start time.Time) { + dur := zap.Duration("took", time.Since(start)) + if err != nil { + l.logger.Error("failed to session find", zap.Error(err), dur) + return + } + l.logger.Debug("session find", dur) + }(time.Now()) + return l.sessionService.FindSession(ctx, key) + +} + +// ExpireSession calls the underlying session service and logs the results of the request +func (l *SessionLogger) ExpireSession(ctx context.Context, key string) (err error) { + defer func(start time.Time) { + dur := zap.Duration("took", time.Since(start)) + if err != nil { + l.logger.Error("failed to session expire", zap.Error(err), dur) + return + } + l.logger.Debug("session expire", dur) + }(time.Now()) + return l.sessionService.ExpireSession(ctx, key) + +} + +// CreateSession calls the underlying session service and logs the results of the request +func (l *SessionLogger) CreateSession(ctx context.Context, user string) (s *influxdb.Session, err error) { + defer func(start time.Time) { + dur := zap.Duration("took", time.Since(start)) + if err != nil { + l.logger.Error("failed to session create", zap.Error(err), dur) + return + } + l.logger.Debug("session create", dur) + }(time.Now()) + return l.sessionService.CreateSession(ctx, user) + +} + +// RenewSession calls the underlying session service and logs the results of the request +func (l *SessionLogger) RenewSession(ctx context.Context, session *influxdb.Session, newExpiration time.Time) (err error) { + defer func(start time.Time) { + dur := zap.Duration("took", time.Since(start)) + if err != nil { + l.logger.Error("failed to session renew", zap.Error(err), dur) + return + } + l.logger.Debug("session renew", dur) + }(time.Now()) + return l.sessionService.RenewSession(ctx, session, newExpiration) + +} diff --git a/session/middleware_metrics.go b/session/middleware_metrics.go new file mode 100644 index 0000000000..6e16d02799 --- /dev/null +++ b/session/middleware_metrics.go @@ -0,0 +1,56 @@ +package session + +import ( + "context" + "time" + + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/kit/metric" + "github.com/prometheus/client_golang/prometheus" +) + +// SessionMetrics is a metrics middleware system for the session service +type SessionMetrics struct { + // RED metrics + rec *metric.REDClient + + sessionSvc influxdb.SessionService +} + +var _ influxdb.SessionService = (*SessionMetrics)(nil) + +// NewSessionMetrics creates a new session metrics middleware +func NewSessionMetrics(reg prometheus.Registerer, s influxdb.SessionService) *SessionMetrics { + return &SessionMetrics{ + rec: metric.New(reg, "session"), + sessionSvc: s, + } +} + +// FindSession calls the underlying session service and tracks RED metrics for the call +func (m *SessionMetrics) FindSession(ctx context.Context, key string) (session *influxdb.Session, err error) { + rec := m.rec.Record("find_session") + session, err = m.sessionSvc.FindSession(ctx, key) + return session, rec(err) +} + +// ExpireSession calls the underlying session service and tracks RED metrics for the call +func (m *SessionMetrics) ExpireSession(ctx context.Context, key string) (err error) { + rec := m.rec.Record("expire_session") + err = m.sessionSvc.ExpireSession(ctx, key) + return rec(err) +} + +// CreateSession calls the underlying session service and tracks RED metrics for the call +func (m *SessionMetrics) CreateSession(ctx context.Context, user string) (s *influxdb.Session, err error) { + rec := m.rec.Record("create_session") + s, err = m.sessionSvc.CreateSession(ctx, user) + return s, rec(err) +} + +// RenewSession calls the underlying session service and tracks RED metrics for the call +func (m *SessionMetrics) RenewSession(ctx context.Context, session *influxdb.Session, newExpiration time.Time) (err error) { + rec := m.rec.Record("renew_session") + err = m.sessionSvc.RenewSession(ctx, session, newExpiration) + return rec(err) +} diff --git a/session/service.go b/session/service.go new file mode 100644 index 0000000000..2df6bcd156 --- /dev/null +++ b/session/service.go @@ -0,0 +1,175 @@ +package session + +import ( + "context" + "time" + + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/rand" + "github.com/influxdata/influxdb/v2/snowflake" +) + +// Service implements the influxdb.SessionService interface and +// handles communication between session and the necessary user and urm services +type Service struct { + store *Storage + userService influxdb.UserService + urmService influxdb.UserResourceMappingService + authService influxdb.AuthorizationService + sessionLength time.Duration + + idGen influxdb.IDGenerator + tokenGen influxdb.TokenGenerator + + disableAuthorizationsForMaxPermissions func(context.Context) bool +} + +// NewService creates a new session service +func NewService(store *Storage, userService influxdb.UserService, urmService influxdb.UserResourceMappingService, authSvc influxdb.AuthorizationService, sessionLength time.Duration) *Service { + if sessionLength <= 0 { + sessionLength = time.Hour + } + return &Service{ + store: store, + userService: userService, + urmService: urmService, + authService: authSvc, + sessionLength: sessionLength, + idGen: snowflake.NewIDGenerator(), + tokenGen: rand.NewTokenGenerator(64), + disableAuthorizationsForMaxPermissions: func(context.Context) bool { + return false + }, + } +} + +// WithMaxPermissionFunc sets the useAuthorizationsForMaxPermissions function +// which can trigger whether or not max permissions uses the users authorizations +// to derive maximum permissions. +func (s *Service) WithMaxPermissionFunc(fn func(context.Context) bool) { + s.disableAuthorizationsForMaxPermissions = fn +} + +// FindSession finds a session based on the session key +func (s *Service) FindSession(ctx context.Context, key string) (*influxdb.Session, error) { + session, err := s.store.FindSessionByKey(ctx, key) + if err != nil { + return nil, err + } + + // TODO: We want to be able to store permissions in the session + // but the contract provided by urm's doesn't give us enough information to quickly repopulate our + // session permissions on updates so we are required to pull the permissions every time we find the session. + permissions, err := s.getPermissionSet(ctx, session.UserID) + if err != nil { + return nil, err + } + + session.Permissions = permissions + return session, nil +} + +// ExpireSession removes a session from the system +func (s *Service) ExpireSession(ctx context.Context, key string) error { + session, err := s.store.FindSessionByKey(ctx, key) + if err != nil { + return err + } + return s.store.DeleteSession(ctx, session.ID) +} + +// CreateSession +func (s *Service) CreateSession(ctx context.Context, user string) (*influxdb.Session, error) { + u, err := s.userService.FindUser(ctx, influxdb.UserFilter{ + Name: &user, + }) + if err != nil { + return nil, err + } + + token, err := s.tokenGen.Token() + if err != nil { + return nil, err + } + + // for now we are not storing the permissions because we need to pull them every time we find + // so we might as well keep the session stored small + now := time.Now() + session := &influxdb.Session{ + ID: s.idGen.ID(), + Key: token, + CreatedAt: now, + ExpiresAt: now.Add(s.sessionLength), + UserID: u.ID, + } + + return session, s.store.CreateSession(ctx, session) +} + +// RenewSession update the sessions expiration time +func (s *Service) RenewSession(ctx context.Context, session *influxdb.Session, newExpiration time.Time) error { + if session == nil { + return &influxdb.Error{ + Msg: "session is nil", + } + } + return s.store.RefreshSession(ctx, session.ID, newExpiration) +} + +func (s *Service) getPermissionSet(ctx context.Context, uid influxdb.ID) ([]influxdb.Permission, error) { + + mappings, _, err := s.urmService.FindUserResourceMappings(ctx, influxdb.UserResourceMappingFilter{UserID: uid}, influxdb.FindOptions{Limit: 100}) + if err != nil { + return nil, err + } + + permissions, err := permissionFromMapping(mappings) + if err != nil { + return nil, err + } + + if len(mappings) == 100 { + // if we got 100 mappings we probably need to pull more pages + // account for paginated results + for i := len(mappings); len(mappings) > 0; i += len(mappings) { + mappings, _, err = s.urmService.FindUserResourceMappings(ctx, influxdb.UserResourceMappingFilter{UserID: uid}, influxdb.FindOptions{Offset: i, Limit: 100}) + if err != nil { + return nil, err + } + pms, err := permissionFromMapping(mappings) + if err != nil { + return nil, err + } + permissions = append(permissions, pms...) + } + } + + if !s.disableAuthorizationsForMaxPermissions(ctx) { + as, _, err := s.authService.FindAuthorizations(ctx, influxdb.AuthorizationFilter{UserID: &uid}) + if err != nil { + return nil, err + } + for _, a := range as { + permissions = append(permissions, a.Permissions...) + } + } + + permissions = append(permissions, influxdb.MePermissions(uid)...) + return permissions, nil +} + +func permissionFromMapping(mappings []*influxdb.UserResourceMapping) ([]influxdb.Permission, error) { + ps := make([]influxdb.Permission, 0, len(mappings)) + for _, m := range mappings { + p, err := m.ToPermissions() + if err != nil { + return nil, &influxdb.Error{ + Err: err, + } + } + + ps = append(ps, p...) + } + + return ps, nil +} diff --git a/session/service_test.go b/session/service_test.go new file mode 100644 index 0000000000..5c233aa11a --- /dev/null +++ b/session/service_test.go @@ -0,0 +1,49 @@ +package session + +import ( + "context" + "testing" + "time" + + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/inmem" + "github.com/influxdata/influxdb/v2/mock" + "github.com/influxdata/influxdb/v2/tenant" + influxdbtesting "github.com/influxdata/influxdb/v2/testing" +) + +func TestSessionService(t *testing.T) { + influxdbtesting.SessionService(initSessionService, t) +} + +func initSessionService(f influxdbtesting.SessionFields, t *testing.T) (influxdb.SessionService, string, func()) { + ss := NewStorage(inmem.NewSessionStore()) + ts, _ := tenant.NewStore(inmem.NewKVStore()) + ten := tenant.NewService(ts) + svc := NewService(ss, ten, ten, &mock.AuthorizationService{ + FindAuthorizationsFn: func(context.Context, influxdb.AuthorizationFilter, ...influxdb.FindOptions) ([]*influxdb.Authorization, int, error) { + return []*influxdb.Authorization{}, 0, nil + }, + }, time.Minute) + + if f.IDGenerator != nil { + svc.idGen = f.IDGenerator + } + + if f.TokenGenerator != nil { + svc.tokenGen = f.TokenGenerator + } + + ctx := context.Background() + for _, u := range f.Users { + if err := ten.CreateUser(ctx, u); err != nil { + t.Fatalf("failed to populate users") + } + } + for _, s := range f.Sessions { + if err := ss.CreateSession(ctx, s); err != nil { + t.Fatalf("failed to populate sessions") + } + } + return svc, "session", func() {} +} diff --git a/session/storage.go b/session/storage.go new file mode 100644 index 0000000000..e84e4f01d0 --- /dev/null +++ b/session/storage.go @@ -0,0 +1,136 @@ +package session + +import ( + "context" + "encoding/json" + "time" + + "github.com/influxdata/influxdb/v2" +) + +type Store interface { + Set(key, val string, expireAt time.Time) error + Get(key string) (string, error) + Delete(key string) error + ExpireAt(key string, expireAt time.Time) error +} + +var storePrefix = "sessionsv1/" +var storeIndex = "sessionsindexv1/" + +// Storage is a store translation layer between the data storage unit and the +// service layer. +type Storage struct { + store Store +} + +// NewStorage creates a new storage system +func NewStorage(s Store) *Storage { + return &Storage{s} +} + +// FindSessionByKey use a given key to retrieve the stored session +func (s *Storage) FindSessionByKey(ctx context.Context, key string) (*influxdb.Session, error) { + val, err := s.store.Get(sessionIndexKey(key)) + if err != nil { + return nil, err + } + + if val == "" { + return nil, &influxdb.Error{ + Code: influxdb.ENotFound, + Msg: influxdb.ErrSessionNotFound, + } + } + + id, err := influxdb.IDFromString(val) + if err != nil { + return nil, err + } + return s.FindSessionByID(ctx, *id) +} + +// FindSessionByID use a provided id to retrieve the stored session +func (s *Storage) FindSessionByID(ctx context.Context, id influxdb.ID) (*influxdb.Session, error) { + val, err := s.store.Get(storePrefix + id.String()) + if err != nil { + return nil, err + } + if val == "" { + return nil, &influxdb.Error{ + Code: influxdb.ENotFound, + Msg: influxdb.ErrSessionNotFound, + } + } + + session := &influxdb.Session{} + return session, json.Unmarshal([]byte(val), session) +} + +// CreateSession creates a new session +func (s *Storage) CreateSession(ctx context.Context, session *influxdb.Session) error { + // create session + sessionBytes, err := json.Marshal(session) + if err != nil { + return err + } + + // use a minute time just so the session will expire if we fail to set the expiration later + sessionID := sessionID(session.ID) + if err := s.store.Set(sessionID, string(sessionBytes), session.ExpiresAt); err != nil { + return err + } + + // create index + indexKey := sessionIndexKey(session.Key) + if err := s.store.Set(indexKey, session.ID.String(), session.ExpiresAt); err != nil { + return err + } + + return nil +} + +// RefreshSession updates the expiration time of a session. +func (s *Storage) RefreshSession(ctx context.Context, id influxdb.ID, expireAt time.Time) error { + session, err := s.FindSessionByID(ctx, id) + if err != nil { + return err + } + + if expireAt.Before(session.ExpiresAt) { + // no need to recreate the session if we aren't extending the expiration + return nil + } + + session.ExpiresAt = expireAt + return s.CreateSession(ctx, session) +} + +// DeleteSession removes the session and index from storage +func (s *Storage) DeleteSession(ctx context.Context, id influxdb.ID) error { + session, err := s.FindSessionByID(ctx, id) + if err != nil { + return err + } + if session == nil { + return nil + } + + if err := s.store.Delete(sessionID(session.ID)); err != nil { + return err + } + + if err := s.store.Delete(sessionIndexKey(session.Key)); err != nil { + return err + } + + return nil +} + +func sessionID(id influxdb.ID) string { + return storePrefix + id.String() +} + +func sessionIndexKey(key string) string { + return storeIndex + key +} diff --git a/session/storage_test.go b/session/storage_test.go new file mode 100644 index 0000000000..3e208c65e3 --- /dev/null +++ b/session/storage_test.go @@ -0,0 +1,120 @@ +package session_test + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/inmem" + "github.com/influxdata/influxdb/v2/session" +) + +func TestSessionStore(t *testing.T) { + driver := func() session.Store { + return inmem.NewSessionStore() + } + + expected := &influxdb.Session{ + ID: 1, + Key: "2", + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Hour), + } + + simpleSetup := func(t *testing.T, store *session.Storage) { + err := store.CreateSession( + context.Background(), + expected, + ) + if err != nil { + t.Fatal(err) + } + } + + st := []struct { + name string + setup func(*testing.T, *session.Storage) + update func(*testing.T, *session.Storage) + results func(*testing.T, *session.Storage) + }{ + { + name: "create", + setup: simpleSetup, + results: func(t *testing.T, store *session.Storage) { + session, err := store.FindSessionByID(context.Background(), 1) + if err != nil { + t.Fatal(err) + } + + if !cmp.Equal(session, expected) { + t.Fatalf("expected identical sessions: \n%+v\n%+v", session, expected) + } + }, + }, + { + name: "get", + setup: simpleSetup, + results: func(t *testing.T, store *session.Storage) { + session, err := store.FindSessionByID(context.Background(), 1) + if err != nil { + t.Fatal(err) + } + + if !cmp.Equal(session, expected) { + t.Fatalf("expected identical sessions: \n%+v\n%+v", session, expected) + } + + session, err = store.FindSessionByKey(context.Background(), "2") + if err != nil { + t.Fatal(err) + } + + if !cmp.Equal(session, expected) { + t.Fatalf("expected identical sessions: \n%+v\n%+v", session, expected) + } + }, + }, + { + name: "delete", + setup: simpleSetup, + update: func(t *testing.T, store *session.Storage) { + err := store.DeleteSession(context.Background(), 1) + if err != nil { + t.Fatal(err) + } + }, + results: func(t *testing.T, store *session.Storage) { + session, err := store.FindSessionByID(context.Background(), 1) + if err == nil { + t.Fatal("expected error on deleted session but got none") + } + + if session != nil { + t.Fatal("got a session when none should have existed") + } + }, + }, + } + for _, testScenario := range st { + t.Run(testScenario.name, func(t *testing.T) { + ss := session.NewStorage(driver()) + + // setup + if testScenario.setup != nil { + testScenario.setup(t, ss) + } + + // update + if testScenario.update != nil { + testScenario.update(t, ss) + } + + // results + if testScenario.results != nil { + testScenario.results(t, ss) + } + }) + } +} diff --git a/task/backend/analytical_storage.go b/task/backend/analytical_storage.go index 38b00fe1b9..9820f40d67 100644 --- a/task/backend/analytical_storage.go +++ b/task/backend/analytical_storage.go @@ -326,7 +326,7 @@ func (as *AnalyticalStorage) RetryRun(ctx context.Context, taskID, runID influxd return run, err } - // try finding the run (in our system or underlieing) + // try finding the run (in our system or underlying) run, err = as.FindRunByID(ctx, taskID, runID) if err != nil { return run, err diff --git a/tenant/middleware_onboarding_auth.go b/tenant/middleware_onboarding_auth.go index daf5304b7c..2861874b50 100644 --- a/tenant/middleware_onboarding_auth.go +++ b/tenant/middleware_onboarding_auth.go @@ -24,12 +24,12 @@ func NewAuthedOnboardSvc(s influxdb.OnboardingService) *AuthedOnboardSvc { } } -// IsOnboarding pass through. this is handled by the underlieing service layer +// IsOnboarding pass through. this is handled by the underlying service layer func (s *AuthedOnboardSvc) IsOnboarding(ctx context.Context) (bool, error) { return s.s.IsOnboarding(ctx) } -// OnboardInitialUser pass through. this is handled by the underlieing service layer +// OnboardInitialUser pass through. this is handled by the underlying service layer func (s *AuthedOnboardSvc) OnboardInitialUser(ctx context.Context, req *influxdb.OnboardingRequest) (*influxdb.OnboardingResults, error) { return s.s.OnboardInitialUser(ctx, req) } diff --git a/tenant/storage.go b/tenant/storage.go index b2a0e19f82..611d6e3d45 100644 --- a/tenant/storage.go +++ b/tenant/storage.go @@ -68,6 +68,11 @@ func (s *Store) setup() error { if _, err := tx.Bucket(urmBucket); err != nil { return err } + + if _, err := tx.Bucket(urmByUserIndexBucket); err != nil { + return err + } + if _, err := tx.Bucket(organizationBucket); err != nil { return err } diff --git a/testing/session.go b/testing/session.go index e926caecc9..cb794905cd 100644 --- a/testing/session.go +++ b/testing/session.go @@ -9,6 +9,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/influxdata/influxdb/v2" platform "github.com/influxdata/influxdb/v2" "github.com/influxdata/influxdb/v2/mock" ) @@ -274,12 +275,12 @@ func ExpireSession( diffPlatformErrors(tt.name, err, tt.wants.err, opPrefix, t) session, err := s.FindSession(ctx, tt.args.key) - if err == nil { + if err.Error() != influxdb.ErrSessionExpired && err.Error() != influxdb.ErrSessionNotFound { t.Errorf("expected session to be expired got %v", err) } - if diff := cmp.Diff(session, tt.wants.session, sessionCmpOptions...); diff != "" { - t.Errorf("session is different -got/+want\ndiff %s", diff) + if session != nil { + t.Errorf("expected a nil session but got: %v", session) } }) }