feat(session): Build out a new session service (#17950)

This new session service has the ability to work independant of other systems
it relies on having its own store type which should allow us to be more flexible
then using the built in kv system.

I have included an in mem session store.
pull/18049/head
Lyon Hill 2020-05-11 15:04:11 -06:00 committed by GitHub
parent c6b2fc5d2c
commit bdc882f6ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1186 additions and 8 deletions

86
inmem/session_store.go Normal file
View File

@ -0,0 +1,86 @@
package inmem
import (
"errors"
"sync"
"time"
)
type SessionStore struct {
data map[string]string
timers map[string]*time.Timer
mu sync.RWMutex
}
func NewSessionStore() *SessionStore {
return &SessionStore{
data: map[string]string{},
timers: map[string]*time.Timer{},
}
}
func (s *SessionStore) Set(key, val string, expireAt time.Time) error {
if !expireAt.IsZero() && expireAt.Before(time.Now()) {
// key is already expired. no problem
return nil
}
s.mu.Lock()
s.data[key] = val
s.mu.Unlock()
if !expireAt.IsZero() {
return s.ExpireAt(key, expireAt)
}
return nil
}
func (s *SessionStore) Get(key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
return s.data[key], nil
}
func (s *SessionStore) Delete(key string) error {
s.mu.Lock()
defer s.mu.Unlock()
timer := s.timers[key]
if timer != nil {
timer.Stop()
}
delete(s.data, key)
delete(s.timers, key)
return nil
}
func (s *SessionStore) ExpireAt(key string, expireAt time.Time) error {
s.mu.Lock()
existingTimer, ok := s.timers[key]
if ok {
if !existingTimer.Stop() {
return errors.New("session has expired")
}
}
duration := time.Until(expireAt)
if duration <= 0 {
s.mu.Unlock()
s.Delete(key)
return nil
}
s.timers[key] = time.AfterFunc(time.Until(expireAt), s.timerExpireFunc(key))
s.mu.Unlock()
return nil
}
func (s *SessionStore) timerExpireFunc(key string) func() {
return func() {
s.Delete(key)
}
}

162
inmem/session_store_test.go Normal file
View File

@ -0,0 +1,162 @@
package inmem_test
import (
"testing"
"time"
"github.com/influxdata/influxdb/v2/inmem"
)
func TestSessionSet(t *testing.T) {
st := inmem.NewSessionStore()
err := st.Set("hi", "friend", time.Time{})
if err != nil {
t.Fatal(err)
}
err = st.Set("hi", "enemy", time.Time{})
if err != nil {
t.Fatal(err)
}
word, err := st.Get("hi")
if err != nil {
t.Fatal(err)
}
if word != "enemy" {
t.Fatalf("got incorrect response: got %s expected: \"enemy\"", word)
}
}
func TestSessionGet(t *testing.T) {
st := inmem.NewSessionStore()
err := st.Set("hi", "friend", time.Now().Add(time.Second))
if err != nil {
t.Fatal(err)
}
word, err := st.Get("hi")
if err != nil {
t.Fatal(err)
}
if word != "friend" {
t.Fatalf("got incorrect response: got %s expected: \"enemy\"", word)
}
time.Sleep(time.Second * 2)
word, err = st.Get("hi")
if err != nil {
t.Fatal(err)
}
if word != "" {
t.Fatalf("expected no words back but got: %s", word)
}
}
func TestSessionDelete(t *testing.T) {
st := inmem.NewSessionStore()
err := st.Set("hi", "friend", time.Time{})
if err != nil {
t.Fatal(err)
}
if err := st.Delete("hi"); err != nil {
t.Fatal(err)
}
word, err := st.Get("hi")
if err != nil {
t.Fatal(err)
}
if word != "" {
t.Fatalf("expected no words back but got: %s", word)
}
if err := st.Delete("hi"); err != nil {
t.Fatal(err)
}
}
func TestSessionExpireAt(t *testing.T) {
st := inmem.NewSessionStore()
err := st.Set("hi", "friend", time.Time{})
if err != nil {
t.Fatal(err)
}
if err := st.ExpireAt("hi", time.Now().Add(-20)); err != nil {
t.Fatal(err)
}
word, err := st.Get("hi")
if err != nil {
t.Fatal(err)
}
if word != "" {
t.Fatalf("expected no words back but got: %s", word)
}
if err := st.Set("hello", "friend", time.Time{}); err != nil {
t.Fatal(err)
}
if err := st.ExpireAt("hello", time.Now()); err != nil {
t.Fatal(err)
}
word, err = st.Get("hello")
if err != nil {
t.Fatal(err)
}
if word != "" {
t.Fatalf("expected no words back but got: %s", word)
}
if err := st.Set("yo", "friend", time.Time{}); err != nil {
t.Fatal(err)
}
if err := st.ExpireAt("yo", time.Now().Add(100*time.Microsecond)); err != nil {
t.Fatal(err)
}
word, err = st.Get("yo")
if err != nil {
t.Fatal(err)
}
if word != "friend" {
t.Fatalf("expected no words back but got: %q", word)
}
// add more time to a key
if err := st.ExpireAt("yo", time.Now().Add(time.Second)); err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
word, err = st.Get("yo")
if err != nil {
t.Fatal(err)
}
if word != "friend" {
t.Fatalf("expected key to still exist but got: %q", word)
}
time.Sleep(time.Second)
word, err = st.Get("yo")
if err != nil {
t.Fatal(err)
}
if word != "" {
t.Fatalf("expected no words back but got: %s", word)
}
}

View File

@ -79,8 +79,7 @@ func (s *Service) FindSession(ctx context.Context, key string) (*influxdb.Sessio
}
if err := sess.Expired(); err != nil {
// todo(leodido) > do we want to return session also if expired?
return sess, &influxdb.Error{
return nil, &influxdb.Error{
Err: err,
}
}

View File

@ -92,5 +92,7 @@ type SessionService interface {
FindSession(ctx context.Context, key string) (*Session, error)
ExpireSession(ctx context.Context, key string) error
CreateSession(ctx context.Context, user string) (*Session, error)
// TODO: update RenewSession to take a ID instead of a session.
// By taking a session object it could be confused to update more things about the session
RenewSession(ctx context.Context, session *Session, newExpiration time.Time) error
}

14
session/errors.go Normal file
View File

@ -0,0 +1,14 @@
package session
import (
"github.com/influxdata/influxdb/v2"
)
var (
// ErrUnauthorized when a session request is unauthorized
// usually due to password missmatch
ErrUnauthorized = &influxdb.Error{
Code: influxdb.EUnauthorized,
Msg: "unauthorized access",
}
)

194
session/http_server.go Normal file
View File

@ -0,0 +1,194 @@
package session
import (
"context"
"net/http"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/influxdata/influxdb/v2"
kithttp "github.com/influxdata/influxdb/v2/kit/transport/http"
"go.uber.org/zap"
)
const (
prefixSignIn = "/api/v2/signin"
prefixSignOut = "/api/v2/signout"
)
// SessionHandler represents an HTTP API handler for authorizations.
type SessionHandler struct {
chi.Router
api *kithttp.API
log *zap.Logger
sessionSvc influxdb.SessionService
passSvc influxdb.PasswordsService
userSvc influxdb.UserService
}
// NewSessionHandler returns a new instance of SessionHandler.
func NewSessionHandler(log *zap.Logger, sessionSvc influxdb.SessionService, userSvc influxdb.UserService, passwordsSvc influxdb.PasswordsService) *SessionHandler {
svr := &SessionHandler{
api: kithttp.NewAPI(kithttp.WithLog(log)),
log: log,
passSvc: passwordsSvc,
sessionSvc: sessionSvc,
userSvc: userSvc,
}
return svr
}
type resourceHandler struct {
prefix string
*SessionHandler
}
// Prefix is necessary to mount the router as a resource handler
func (r resourceHandler) Prefix() string { return r.prefix }
// SignInResourceHandler allows us to return 2 different rousource handler
// for the appropriate mounting location
func (h SessionHandler) SignInResourceHandler() *resourceHandler {
h.Router = chi.NewRouter()
h.Router.Use(
middleware.Recoverer,
middleware.RequestID,
middleware.RealIP,
)
h.Router.Post("/", h.handleSignin)
return &resourceHandler{prefix: prefixSignIn, SessionHandler: &h}
}
// SignOutResourceHandler allows us to return 2 different rousource handler
// for the appropriate mounting location
func (h SessionHandler) SignOutResourceHandler() *resourceHandler {
h.Router = chi.NewRouter()
h.Router.Use(
middleware.Recoverer,
middleware.RequestID,
middleware.RealIP,
)
h.Router.Post("/", h.handleSignout)
return &resourceHandler{prefix: prefixSignOut, SessionHandler: &h}
}
// handleSignin is the HTTP handler for the POST /signin route.
func (h *SessionHandler) handleSignin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
req, decErr := decodeSigninRequest(ctx, r)
if decErr != nil {
h.api.Err(w, ErrUnauthorized)
return
}
u, err := h.userSvc.FindUser(ctx, influxdb.UserFilter{
Name: &req.Username,
})
if err != nil {
h.api.Err(w, ErrUnauthorized)
return
}
if err := h.passSvc.ComparePassword(ctx, u.ID, req.Password); err != nil {
h.api.Err(w, ErrUnauthorized)
return
}
s, e := h.sessionSvc.CreateSession(ctx, req.Username)
if e != nil {
h.api.Err(w, ErrUnauthorized)
return
}
encodeCookieSession(w, s)
w.WriteHeader(http.StatusNoContent)
}
type signinRequest struct {
Username string
Password string
}
func decodeSigninRequest(ctx context.Context, r *http.Request) (*signinRequest, *influxdb.Error) {
u, p, ok := r.BasicAuth()
if !ok {
return nil, &influxdb.Error{
Code: influxdb.EInvalid,
Msg: "invalid basic auth",
}
}
return &signinRequest{
Username: u,
Password: p,
}, nil
}
// handleSignout is the HTTP handler for the POST /signout route.
func (h *SessionHandler) handleSignout(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
req, err := decodeSignoutRequest(ctx, r)
if err != nil {
h.api.Err(w, ErrUnauthorized)
return
}
if err := h.sessionSvc.ExpireSession(ctx, req.Key); err != nil {
h.api.Err(w, ErrUnauthorized)
return
}
w.WriteHeader(http.StatusNoContent)
}
type signoutRequest struct {
Key string
}
func decodeSignoutRequest(ctx context.Context, r *http.Request) (*signoutRequest, error) {
key, err := decodeCookieSession(ctx, r)
if err != nil {
return nil, err
}
return &signoutRequest{
Key: key,
}, nil
}
const cookieSessionName = "session"
func encodeCookieSession(w http.ResponseWriter, s *influxdb.Session) {
c := &http.Cookie{
Name: cookieSessionName,
Value: s.Key,
}
http.SetCookie(w, c)
}
func decodeCookieSession(ctx context.Context, r *http.Request) (string, error) {
c, err := r.Cookie(cookieSessionName)
if err != nil {
return "", &influxdb.Error{
Code: influxdb.EInvalid,
Err: err,
}
}
return c.Value, nil
}
// SetCookieSession adds a cookie for the session to an http request
func SetCookieSession(key string, r *http.Request) {
c := &http.Cookie{
Name: cookieSessionName,
Value: key,
Secure: true,
}
r.AddCookie(c)
}

View File

@ -0,0 +1,98 @@
package session
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/influxdata/influxdb/v2"
"github.com/influxdata/influxdb/v2/mock"
"go.uber.org/zap/zaptest"
)
func TestSessionHandler_handleSignin(t *testing.T) {
type fields struct {
PasswordsService influxdb.PasswordsService
SessionService influxdb.SessionService
}
type args struct {
user string
password string
}
type wants struct {
cookie string
code int
}
tests := []struct {
name string
fields fields
args args
wants wants
}{
{
name: "successful compare password",
fields: fields{
SessionService: &mock.SessionService{
CreateSessionFn: func(context.Context, string) (*influxdb.Session, error) {
return &influxdb.Session{
ID: influxdb.ID(0),
Key: "abc123xyz",
CreatedAt: time.Date(2018, 9, 26, 0, 0, 0, 0, time.UTC),
ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC),
UserID: influxdb.ID(1),
}, nil
},
},
PasswordsService: &mock.PasswordsService{
ComparePasswordFn: func(context.Context, influxdb.ID, string) error {
return nil
},
},
},
args: args{
user: "user1",
password: "supersecret",
},
wants: wants{
cookie: "session=abc123xyz",
code: http.StatusNoContent,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userSVC := mock.NewUserService()
userSVC.FindUserFn = func(_ context.Context, f influxdb.UserFilter) (*influxdb.User, error) {
return &influxdb.User{ID: 1}, nil
}
h := NewSessionHandler(zaptest.NewLogger(t), tt.fields.SessionService, userSVC, tt.fields.PasswordsService)
server := httptest.NewServer(h.SignInResourceHandler())
client := server.Client()
r, err := http.NewRequest("POST", server.URL, nil)
if err != nil {
t.Fatal(err)
}
r.SetBasicAuth(tt.args.user, tt.args.password)
resp, err := client.Do(r)
if err != nil {
t.Fatal(err)
}
if got, want := resp.StatusCode, tt.wants.code; got != want {
t.Errorf("bad status code: got %d want %d", got, want)
}
cookie := resp.Header.Get("Set-Cookie")
if got, want := cookie, tt.wants.cookie; got != want {
t.Errorf("expected session cookie to be set: got %q want %q", got, want)
}
})
}
}

View File

@ -0,0 +1,81 @@
package session
import (
"context"
"time"
"github.com/influxdata/influxdb/v2"
"go.uber.org/zap"
)
// SessionLogger is a logger service middleware for sessions
type SessionLogger struct {
logger *zap.Logger
sessionService influxdb.SessionService
}
var _ influxdb.SessionService = (*SessionLogger)(nil)
// NewSessionLogger returns a logging service middleware for the User Service.
func NewSessionLogger(log *zap.Logger, s influxdb.SessionService) *SessionLogger {
return &SessionLogger{
logger: log,
sessionService: s,
}
}
// FindSession calls the underlying session service and logs the results of the request
func (l *SessionLogger) FindSession(ctx context.Context, key string) (session *influxdb.Session, err error) {
defer func(start time.Time) {
dur := zap.Duration("took", time.Since(start))
if err != nil {
l.logger.Error("failed to session find", zap.Error(err), dur)
return
}
l.logger.Debug("session find", dur)
}(time.Now())
return l.sessionService.FindSession(ctx, key)
}
// ExpireSession calls the underlying session service and logs the results of the request
func (l *SessionLogger) ExpireSession(ctx context.Context, key string) (err error) {
defer func(start time.Time) {
dur := zap.Duration("took", time.Since(start))
if err != nil {
l.logger.Error("failed to session expire", zap.Error(err), dur)
return
}
l.logger.Debug("session expire", dur)
}(time.Now())
return l.sessionService.ExpireSession(ctx, key)
}
// CreateSession calls the underlying session service and logs the results of the request
func (l *SessionLogger) CreateSession(ctx context.Context, user string) (s *influxdb.Session, err error) {
defer func(start time.Time) {
dur := zap.Duration("took", time.Since(start))
if err != nil {
l.logger.Error("failed to session create", zap.Error(err), dur)
return
}
l.logger.Debug("session create", dur)
}(time.Now())
return l.sessionService.CreateSession(ctx, user)
}
// RenewSession calls the underlying session service and logs the results of the request
func (l *SessionLogger) RenewSession(ctx context.Context, session *influxdb.Session, newExpiration time.Time) (err error) {
defer func(start time.Time) {
dur := zap.Duration("took", time.Since(start))
if err != nil {
l.logger.Error("failed to session renew", zap.Error(err), dur)
return
}
l.logger.Debug("session renew", dur)
}(time.Now())
return l.sessionService.RenewSession(ctx, session, newExpiration)
}

View File

@ -0,0 +1,56 @@
package session
import (
"context"
"time"
"github.com/influxdata/influxdb/v2"
"github.com/influxdata/influxdb/v2/kit/metric"
"github.com/prometheus/client_golang/prometheus"
)
// SessionMetrics is a metrics middleware system for the session service
type SessionMetrics struct {
// RED metrics
rec *metric.REDClient
sessionSvc influxdb.SessionService
}
var _ influxdb.SessionService = (*SessionMetrics)(nil)
// NewSessionMetrics creates a new session metrics middleware
func NewSessionMetrics(reg prometheus.Registerer, s influxdb.SessionService) *SessionMetrics {
return &SessionMetrics{
rec: metric.New(reg, "session"),
sessionSvc: s,
}
}
// FindSession calls the underlying session service and tracks RED metrics for the call
func (m *SessionMetrics) FindSession(ctx context.Context, key string) (session *influxdb.Session, err error) {
rec := m.rec.Record("find_session")
session, err = m.sessionSvc.FindSession(ctx, key)
return session, rec(err)
}
// ExpireSession calls the underlying session service and tracks RED metrics for the call
func (m *SessionMetrics) ExpireSession(ctx context.Context, key string) (err error) {
rec := m.rec.Record("expire_session")
err = m.sessionSvc.ExpireSession(ctx, key)
return rec(err)
}
// CreateSession calls the underlying session service and tracks RED metrics for the call
func (m *SessionMetrics) CreateSession(ctx context.Context, user string) (s *influxdb.Session, err error) {
rec := m.rec.Record("create_session")
s, err = m.sessionSvc.CreateSession(ctx, user)
return s, rec(err)
}
// RenewSession calls the underlying session service and tracks RED metrics for the call
func (m *SessionMetrics) RenewSession(ctx context.Context, session *influxdb.Session, newExpiration time.Time) (err error) {
rec := m.rec.Record("renew_session")
err = m.sessionSvc.RenewSession(ctx, session, newExpiration)
return rec(err)
}

175
session/service.go Normal file
View File

@ -0,0 +1,175 @@
package session
import (
"context"
"time"
"github.com/influxdata/influxdb/v2"
"github.com/influxdata/influxdb/v2/rand"
"github.com/influxdata/influxdb/v2/snowflake"
)
// Service implements the influxdb.SessionService interface and
// handles communication between session and the necessary user and urm services
type Service struct {
store *Storage
userService influxdb.UserService
urmService influxdb.UserResourceMappingService
authService influxdb.AuthorizationService
sessionLength time.Duration
idGen influxdb.IDGenerator
tokenGen influxdb.TokenGenerator
disableAuthorizationsForMaxPermissions func(context.Context) bool
}
// NewService creates a new session service
func NewService(store *Storage, userService influxdb.UserService, urmService influxdb.UserResourceMappingService, authSvc influxdb.AuthorizationService, sessionLength time.Duration) *Service {
if sessionLength <= 0 {
sessionLength = time.Hour
}
return &Service{
store: store,
userService: userService,
urmService: urmService,
authService: authSvc,
sessionLength: sessionLength,
idGen: snowflake.NewIDGenerator(),
tokenGen: rand.NewTokenGenerator(64),
disableAuthorizationsForMaxPermissions: func(context.Context) bool {
return false
},
}
}
// WithMaxPermissionFunc sets the useAuthorizationsForMaxPermissions function
// which can trigger whether or not max permissions uses the users authorizations
// to derive maximum permissions.
func (s *Service) WithMaxPermissionFunc(fn func(context.Context) bool) {
s.disableAuthorizationsForMaxPermissions = fn
}
// FindSession finds a session based on the session key
func (s *Service) FindSession(ctx context.Context, key string) (*influxdb.Session, error) {
session, err := s.store.FindSessionByKey(ctx, key)
if err != nil {
return nil, err
}
// TODO: We want to be able to store permissions in the session
// but the contract provided by urm's doesn't give us enough information to quickly repopulate our
// session permissions on updates so we are required to pull the permissions every time we find the session.
permissions, err := s.getPermissionSet(ctx, session.UserID)
if err != nil {
return nil, err
}
session.Permissions = permissions
return session, nil
}
// ExpireSession removes a session from the system
func (s *Service) ExpireSession(ctx context.Context, key string) error {
session, err := s.store.FindSessionByKey(ctx, key)
if err != nil {
return err
}
return s.store.DeleteSession(ctx, session.ID)
}
// CreateSession
func (s *Service) CreateSession(ctx context.Context, user string) (*influxdb.Session, error) {
u, err := s.userService.FindUser(ctx, influxdb.UserFilter{
Name: &user,
})
if err != nil {
return nil, err
}
token, err := s.tokenGen.Token()
if err != nil {
return nil, err
}
// for now we are not storing the permissions because we need to pull them every time we find
// so we might as well keep the session stored small
now := time.Now()
session := &influxdb.Session{
ID: s.idGen.ID(),
Key: token,
CreatedAt: now,
ExpiresAt: now.Add(s.sessionLength),
UserID: u.ID,
}
return session, s.store.CreateSession(ctx, session)
}
// RenewSession update the sessions expiration time
func (s *Service) RenewSession(ctx context.Context, session *influxdb.Session, newExpiration time.Time) error {
if session == nil {
return &influxdb.Error{
Msg: "session is nil",
}
}
return s.store.RefreshSession(ctx, session.ID, newExpiration)
}
func (s *Service) getPermissionSet(ctx context.Context, uid influxdb.ID) ([]influxdb.Permission, error) {
mappings, _, err := s.urmService.FindUserResourceMappings(ctx, influxdb.UserResourceMappingFilter{UserID: uid}, influxdb.FindOptions{Limit: 100})
if err != nil {
return nil, err
}
permissions, err := permissionFromMapping(mappings)
if err != nil {
return nil, err
}
if len(mappings) == 100 {
// if we got 100 mappings we probably need to pull more pages
// account for paginated results
for i := len(mappings); len(mappings) > 0; i += len(mappings) {
mappings, _, err = s.urmService.FindUserResourceMappings(ctx, influxdb.UserResourceMappingFilter{UserID: uid}, influxdb.FindOptions{Offset: i, Limit: 100})
if err != nil {
return nil, err
}
pms, err := permissionFromMapping(mappings)
if err != nil {
return nil, err
}
permissions = append(permissions, pms...)
}
}
if !s.disableAuthorizationsForMaxPermissions(ctx) {
as, _, err := s.authService.FindAuthorizations(ctx, influxdb.AuthorizationFilter{UserID: &uid})
if err != nil {
return nil, err
}
for _, a := range as {
permissions = append(permissions, a.Permissions...)
}
}
permissions = append(permissions, influxdb.MePermissions(uid)...)
return permissions, nil
}
func permissionFromMapping(mappings []*influxdb.UserResourceMapping) ([]influxdb.Permission, error) {
ps := make([]influxdb.Permission, 0, len(mappings))
for _, m := range mappings {
p, err := m.ToPermissions()
if err != nil {
return nil, &influxdb.Error{
Err: err,
}
}
ps = append(ps, p...)
}
return ps, nil
}

49
session/service_test.go Normal file
View File

@ -0,0 +1,49 @@
package session
import (
"context"
"testing"
"time"
"github.com/influxdata/influxdb/v2"
"github.com/influxdata/influxdb/v2/inmem"
"github.com/influxdata/influxdb/v2/mock"
"github.com/influxdata/influxdb/v2/tenant"
influxdbtesting "github.com/influxdata/influxdb/v2/testing"
)
func TestSessionService(t *testing.T) {
influxdbtesting.SessionService(initSessionService, t)
}
func initSessionService(f influxdbtesting.SessionFields, t *testing.T) (influxdb.SessionService, string, func()) {
ss := NewStorage(inmem.NewSessionStore())
ts, _ := tenant.NewStore(inmem.NewKVStore())
ten := tenant.NewService(ts)
svc := NewService(ss, ten, ten, &mock.AuthorizationService{
FindAuthorizationsFn: func(context.Context, influxdb.AuthorizationFilter, ...influxdb.FindOptions) ([]*influxdb.Authorization, int, error) {
return []*influxdb.Authorization{}, 0, nil
},
}, time.Minute)
if f.IDGenerator != nil {
svc.idGen = f.IDGenerator
}
if f.TokenGenerator != nil {
svc.tokenGen = f.TokenGenerator
}
ctx := context.Background()
for _, u := range f.Users {
if err := ten.CreateUser(ctx, u); err != nil {
t.Fatalf("failed to populate users")
}
}
for _, s := range f.Sessions {
if err := ss.CreateSession(ctx, s); err != nil {
t.Fatalf("failed to populate sessions")
}
}
return svc, "session", func() {}
}

136
session/storage.go Normal file
View File

@ -0,0 +1,136 @@
package session
import (
"context"
"encoding/json"
"time"
"github.com/influxdata/influxdb/v2"
)
type Store interface {
Set(key, val string, expireAt time.Time) error
Get(key string) (string, error)
Delete(key string) error
ExpireAt(key string, expireAt time.Time) error
}
var storePrefix = "sessionsv1/"
var storeIndex = "sessionsindexv1/"
// Storage is a store translation layer between the data storage unit and the
// service layer.
type Storage struct {
store Store
}
// NewStorage creates a new storage system
func NewStorage(s Store) *Storage {
return &Storage{s}
}
// FindSessionByKey use a given key to retrieve the stored session
func (s *Storage) FindSessionByKey(ctx context.Context, key string) (*influxdb.Session, error) {
val, err := s.store.Get(sessionIndexKey(key))
if err != nil {
return nil, err
}
if val == "" {
return nil, &influxdb.Error{
Code: influxdb.ENotFound,
Msg: influxdb.ErrSessionNotFound,
}
}
id, err := influxdb.IDFromString(val)
if err != nil {
return nil, err
}
return s.FindSessionByID(ctx, *id)
}
// FindSessionByID use a provided id to retrieve the stored session
func (s *Storage) FindSessionByID(ctx context.Context, id influxdb.ID) (*influxdb.Session, error) {
val, err := s.store.Get(storePrefix + id.String())
if err != nil {
return nil, err
}
if val == "" {
return nil, &influxdb.Error{
Code: influxdb.ENotFound,
Msg: influxdb.ErrSessionNotFound,
}
}
session := &influxdb.Session{}
return session, json.Unmarshal([]byte(val), session)
}
// CreateSession creates a new session
func (s *Storage) CreateSession(ctx context.Context, session *influxdb.Session) error {
// create session
sessionBytes, err := json.Marshal(session)
if err != nil {
return err
}
// use a minute time just so the session will expire if we fail to set the expiration later
sessionID := sessionID(session.ID)
if err := s.store.Set(sessionID, string(sessionBytes), session.ExpiresAt); err != nil {
return err
}
// create index
indexKey := sessionIndexKey(session.Key)
if err := s.store.Set(indexKey, session.ID.String(), session.ExpiresAt); err != nil {
return err
}
return nil
}
// RefreshSession updates the expiration time of a session.
func (s *Storage) RefreshSession(ctx context.Context, id influxdb.ID, expireAt time.Time) error {
session, err := s.FindSessionByID(ctx, id)
if err != nil {
return err
}
if expireAt.Before(session.ExpiresAt) {
// no need to recreate the session if we aren't extending the expiration
return nil
}
session.ExpiresAt = expireAt
return s.CreateSession(ctx, session)
}
// DeleteSession removes the session and index from storage
func (s *Storage) DeleteSession(ctx context.Context, id influxdb.ID) error {
session, err := s.FindSessionByID(ctx, id)
if err != nil {
return err
}
if session == nil {
return nil
}
if err := s.store.Delete(sessionID(session.ID)); err != nil {
return err
}
if err := s.store.Delete(sessionIndexKey(session.Key)); err != nil {
return err
}
return nil
}
func sessionID(id influxdb.ID) string {
return storePrefix + id.String()
}
func sessionIndexKey(key string) string {
return storeIndex + key
}

120
session/storage_test.go Normal file
View File

@ -0,0 +1,120 @@
package session_test
import (
"context"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/influxdata/influxdb/v2"
"github.com/influxdata/influxdb/v2/inmem"
"github.com/influxdata/influxdb/v2/session"
)
func TestSessionStore(t *testing.T) {
driver := func() session.Store {
return inmem.NewSessionStore()
}
expected := &influxdb.Session{
ID: 1,
Key: "2",
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Hour),
}
simpleSetup := func(t *testing.T, store *session.Storage) {
err := store.CreateSession(
context.Background(),
expected,
)
if err != nil {
t.Fatal(err)
}
}
st := []struct {
name string
setup func(*testing.T, *session.Storage)
update func(*testing.T, *session.Storage)
results func(*testing.T, *session.Storage)
}{
{
name: "create",
setup: simpleSetup,
results: func(t *testing.T, store *session.Storage) {
session, err := store.FindSessionByID(context.Background(), 1)
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(session, expected) {
t.Fatalf("expected identical sessions: \n%+v\n%+v", session, expected)
}
},
},
{
name: "get",
setup: simpleSetup,
results: func(t *testing.T, store *session.Storage) {
session, err := store.FindSessionByID(context.Background(), 1)
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(session, expected) {
t.Fatalf("expected identical sessions: \n%+v\n%+v", session, expected)
}
session, err = store.FindSessionByKey(context.Background(), "2")
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(session, expected) {
t.Fatalf("expected identical sessions: \n%+v\n%+v", session, expected)
}
},
},
{
name: "delete",
setup: simpleSetup,
update: func(t *testing.T, store *session.Storage) {
err := store.DeleteSession(context.Background(), 1)
if err != nil {
t.Fatal(err)
}
},
results: func(t *testing.T, store *session.Storage) {
session, err := store.FindSessionByID(context.Background(), 1)
if err == nil {
t.Fatal("expected error on deleted session but got none")
}
if session != nil {
t.Fatal("got a session when none should have existed")
}
},
},
}
for _, testScenario := range st {
t.Run(testScenario.name, func(t *testing.T) {
ss := session.NewStorage(driver())
// setup
if testScenario.setup != nil {
testScenario.setup(t, ss)
}
// update
if testScenario.update != nil {
testScenario.update(t, ss)
}
// results
if testScenario.results != nil {
testScenario.results(t, ss)
}
})
}
}

View File

@ -326,7 +326,7 @@ func (as *AnalyticalStorage) RetryRun(ctx context.Context, taskID, runID influxd
return run, err
}
// try finding the run (in our system or underlieing)
// try finding the run (in our system or underlying)
run, err = as.FindRunByID(ctx, taskID, runID)
if err != nil {
return run, err

View File

@ -24,12 +24,12 @@ func NewAuthedOnboardSvc(s influxdb.OnboardingService) *AuthedOnboardSvc {
}
}
// IsOnboarding pass through. this is handled by the underlieing service layer
// IsOnboarding pass through. this is handled by the underlying service layer
func (s *AuthedOnboardSvc) IsOnboarding(ctx context.Context) (bool, error) {
return s.s.IsOnboarding(ctx)
}
// OnboardInitialUser pass through. this is handled by the underlieing service layer
// OnboardInitialUser pass through. this is handled by the underlying service layer
func (s *AuthedOnboardSvc) OnboardInitialUser(ctx context.Context, req *influxdb.OnboardingRequest) (*influxdb.OnboardingResults, error) {
return s.s.OnboardInitialUser(ctx, req)
}

View File

@ -68,6 +68,11 @@ func (s *Store) setup() error {
if _, err := tx.Bucket(urmBucket); err != nil {
return err
}
if _, err := tx.Bucket(urmByUserIndexBucket); err != nil {
return err
}
if _, err := tx.Bucket(organizationBucket); err != nil {
return err
}

View File

@ -9,6 +9,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/influxdata/influxdb/v2"
platform "github.com/influxdata/influxdb/v2"
"github.com/influxdata/influxdb/v2/mock"
)
@ -274,12 +275,12 @@ func ExpireSession(
diffPlatformErrors(tt.name, err, tt.wants.err, opPrefix, t)
session, err := s.FindSession(ctx, tt.args.key)
if err == nil {
if err.Error() != influxdb.ErrSessionExpired && err.Error() != influxdb.ErrSessionNotFound {
t.Errorf("expected session to be expired got %v", err)
}
if diff := cmp.Diff(session, tt.wants.session, sessionCmpOptions...); diff != "" {
t.Errorf("session is different -got/+want\ndiff %s", diff)
if session != nil {
t.Errorf("expected a nil session but got: %v", session)
}
})
}