feat(session): Build out a new session service (#17950)
This new session service has the ability to work independant of other systems it relies on having its own store type which should allow us to be more flexible then using the built in kv system. I have included an in mem session store.pull/18049/head
parent
c6b2fc5d2c
commit
bdc882f6ce
|
@ -0,0 +1,86 @@
|
|||
package inmem
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SessionStore struct {
|
||||
data map[string]string
|
||||
|
||||
timers map[string]*time.Timer
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewSessionStore() *SessionStore {
|
||||
return &SessionStore{
|
||||
data: map[string]string{},
|
||||
timers: map[string]*time.Timer{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SessionStore) Set(key, val string, expireAt time.Time) error {
|
||||
if !expireAt.IsZero() && expireAt.Before(time.Now()) {
|
||||
// key is already expired. no problem
|
||||
return nil
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.data[key] = val
|
||||
s.mu.Unlock()
|
||||
|
||||
if !expireAt.IsZero() {
|
||||
return s.ExpireAt(key, expireAt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SessionStore) Get(key string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return s.data[key], nil
|
||||
}
|
||||
|
||||
func (s *SessionStore) Delete(key string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
timer := s.timers[key]
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
|
||||
delete(s.data, key)
|
||||
delete(s.timers, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SessionStore) ExpireAt(key string, expireAt time.Time) error {
|
||||
s.mu.Lock()
|
||||
|
||||
existingTimer, ok := s.timers[key]
|
||||
if ok {
|
||||
if !existingTimer.Stop() {
|
||||
return errors.New("session has expired")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
duration := time.Until(expireAt)
|
||||
if duration <= 0 {
|
||||
s.mu.Unlock()
|
||||
s.Delete(key)
|
||||
return nil
|
||||
}
|
||||
s.timers[key] = time.AfterFunc(time.Until(expireAt), s.timerExpireFunc(key))
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SessionStore) timerExpireFunc(key string) func() {
|
||||
return func() {
|
||||
s.Delete(key)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
package inmem_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/influxdb/v2/inmem"
|
||||
)
|
||||
|
||||
func TestSessionSet(t *testing.T) {
|
||||
st := inmem.NewSessionStore()
|
||||
err := st.Set("hi", "friend", time.Time{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = st.Set("hi", "enemy", time.Time{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
word, err := st.Get("hi")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if word != "enemy" {
|
||||
t.Fatalf("got incorrect response: got %s expected: \"enemy\"", word)
|
||||
}
|
||||
}
|
||||
func TestSessionGet(t *testing.T) {
|
||||
st := inmem.NewSessionStore()
|
||||
err := st.Set("hi", "friend", time.Now().Add(time.Second))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
word, err := st.Get("hi")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if word != "friend" {
|
||||
t.Fatalf("got incorrect response: got %s expected: \"enemy\"", word)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 2)
|
||||
|
||||
word, err = st.Get("hi")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if word != "" {
|
||||
t.Fatalf("expected no words back but got: %s", word)
|
||||
}
|
||||
}
|
||||
func TestSessionDelete(t *testing.T) {
|
||||
st := inmem.NewSessionStore()
|
||||
err := st.Set("hi", "friend", time.Time{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := st.Delete("hi"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
word, err := st.Get("hi")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if word != "" {
|
||||
t.Fatalf("expected no words back but got: %s", word)
|
||||
}
|
||||
|
||||
if err := st.Delete("hi"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
func TestSessionExpireAt(t *testing.T) {
|
||||
st := inmem.NewSessionStore()
|
||||
err := st.Set("hi", "friend", time.Time{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := st.ExpireAt("hi", time.Now().Add(-20)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
word, err := st.Get("hi")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if word != "" {
|
||||
t.Fatalf("expected no words back but got: %s", word)
|
||||
}
|
||||
|
||||
if err := st.Set("hello", "friend", time.Time{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := st.ExpireAt("hello", time.Now()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
word, err = st.Get("hello")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if word != "" {
|
||||
t.Fatalf("expected no words back but got: %s", word)
|
||||
}
|
||||
|
||||
if err := st.Set("yo", "friend", time.Time{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := st.ExpireAt("yo", time.Now().Add(100*time.Microsecond)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
word, err = st.Get("yo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if word != "friend" {
|
||||
t.Fatalf("expected no words back but got: %q", word)
|
||||
}
|
||||
|
||||
// add more time to a key
|
||||
if err := st.ExpireAt("yo", time.Now().Add(time.Second)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
word, err = st.Get("yo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if word != "friend" {
|
||||
t.Fatalf("expected key to still exist but got: %q", word)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
word, err = st.Get("yo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if word != "" {
|
||||
t.Fatalf("expected no words back but got: %s", word)
|
||||
}
|
||||
}
|
|
@ -79,8 +79,7 @@ func (s *Service) FindSession(ctx context.Context, key string) (*influxdb.Sessio
|
|||
}
|
||||
|
||||
if err := sess.Expired(); err != nil {
|
||||
// todo(leodido) > do we want to return session also if expired?
|
||||
return sess, &influxdb.Error{
|
||||
return nil, &influxdb.Error{
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -92,5 +92,7 @@ type SessionService interface {
|
|||
FindSession(ctx context.Context, key string) (*Session, error)
|
||||
ExpireSession(ctx context.Context, key string) error
|
||||
CreateSession(ctx context.Context, user string) (*Session, error)
|
||||
// TODO: update RenewSession to take a ID instead of a session.
|
||||
// By taking a session object it could be confused to update more things about the session
|
||||
RenewSession(ctx context.Context, session *Session, newExpiration time.Time) error
|
||||
}
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"github.com/influxdata/influxdb/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrUnauthorized when a session request is unauthorized
|
||||
// usually due to password missmatch
|
||||
ErrUnauthorized = &influxdb.Error{
|
||||
Code: influxdb.EUnauthorized,
|
||||
Msg: "unauthorized access",
|
||||
}
|
||||
)
|
|
@ -0,0 +1,194 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/influxdata/influxdb/v2"
|
||||
kithttp "github.com/influxdata/influxdb/v2/kit/transport/http"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
prefixSignIn = "/api/v2/signin"
|
||||
prefixSignOut = "/api/v2/signout"
|
||||
)
|
||||
|
||||
// SessionHandler represents an HTTP API handler for authorizations.
|
||||
type SessionHandler struct {
|
||||
chi.Router
|
||||
api *kithttp.API
|
||||
log *zap.Logger
|
||||
|
||||
sessionSvc influxdb.SessionService
|
||||
passSvc influxdb.PasswordsService
|
||||
userSvc influxdb.UserService
|
||||
}
|
||||
|
||||
// NewSessionHandler returns a new instance of SessionHandler.
|
||||
func NewSessionHandler(log *zap.Logger, sessionSvc influxdb.SessionService, userSvc influxdb.UserService, passwordsSvc influxdb.PasswordsService) *SessionHandler {
|
||||
svr := &SessionHandler{
|
||||
api: kithttp.NewAPI(kithttp.WithLog(log)),
|
||||
log: log,
|
||||
|
||||
passSvc: passwordsSvc,
|
||||
sessionSvc: sessionSvc,
|
||||
userSvc: userSvc,
|
||||
}
|
||||
|
||||
return svr
|
||||
}
|
||||
|
||||
type resourceHandler struct {
|
||||
prefix string
|
||||
*SessionHandler
|
||||
}
|
||||
|
||||
// Prefix is necessary to mount the router as a resource handler
|
||||
func (r resourceHandler) Prefix() string { return r.prefix }
|
||||
|
||||
// SignInResourceHandler allows us to return 2 different rousource handler
|
||||
// for the appropriate mounting location
|
||||
func (h SessionHandler) SignInResourceHandler() *resourceHandler {
|
||||
h.Router = chi.NewRouter()
|
||||
h.Router.Use(
|
||||
middleware.Recoverer,
|
||||
middleware.RequestID,
|
||||
middleware.RealIP,
|
||||
)
|
||||
h.Router.Post("/", h.handleSignin)
|
||||
return &resourceHandler{prefix: prefixSignIn, SessionHandler: &h}
|
||||
}
|
||||
|
||||
// SignOutResourceHandler allows us to return 2 different rousource handler
|
||||
// for the appropriate mounting location
|
||||
func (h SessionHandler) SignOutResourceHandler() *resourceHandler {
|
||||
h.Router = chi.NewRouter()
|
||||
h.Router.Use(
|
||||
middleware.Recoverer,
|
||||
middleware.RequestID,
|
||||
middleware.RealIP,
|
||||
)
|
||||
h.Router.Post("/", h.handleSignout)
|
||||
return &resourceHandler{prefix: prefixSignOut, SessionHandler: &h}
|
||||
}
|
||||
|
||||
// handleSignin is the HTTP handler for the POST /signin route.
|
||||
func (h *SessionHandler) handleSignin(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
req, decErr := decodeSigninRequest(ctx, r)
|
||||
if decErr != nil {
|
||||
h.api.Err(w, ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
u, err := h.userSvc.FindUser(ctx, influxdb.UserFilter{
|
||||
Name: &req.Username,
|
||||
})
|
||||
if err != nil {
|
||||
h.api.Err(w, ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.passSvc.ComparePassword(ctx, u.ID, req.Password); err != nil {
|
||||
h.api.Err(w, ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
s, e := h.sessionSvc.CreateSession(ctx, req.Username)
|
||||
if e != nil {
|
||||
h.api.Err(w, ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
encodeCookieSession(w, s)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
type signinRequest struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
func decodeSigninRequest(ctx context.Context, r *http.Request) (*signinRequest, *influxdb.Error) {
|
||||
u, p, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
return nil, &influxdb.Error{
|
||||
Code: influxdb.EInvalid,
|
||||
Msg: "invalid basic auth",
|
||||
}
|
||||
}
|
||||
|
||||
return &signinRequest{
|
||||
Username: u,
|
||||
Password: p,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleSignout is the HTTP handler for the POST /signout route.
|
||||
func (h *SessionHandler) handleSignout(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
req, err := decodeSignoutRequest(ctx, r)
|
||||
if err != nil {
|
||||
h.api.Err(w, ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.sessionSvc.ExpireSession(ctx, req.Key); err != nil {
|
||||
h.api.Err(w, ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
type signoutRequest struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
func decodeSignoutRequest(ctx context.Context, r *http.Request) (*signoutRequest, error) {
|
||||
key, err := decodeCookieSession(ctx, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &signoutRequest{
|
||||
Key: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
const cookieSessionName = "session"
|
||||
|
||||
func encodeCookieSession(w http.ResponseWriter, s *influxdb.Session) {
|
||||
c := &http.Cookie{
|
||||
Name: cookieSessionName,
|
||||
Value: s.Key,
|
||||
}
|
||||
|
||||
http.SetCookie(w, c)
|
||||
}
|
||||
|
||||
func decodeCookieSession(ctx context.Context, r *http.Request) (string, error) {
|
||||
c, err := r.Cookie(cookieSessionName)
|
||||
if err != nil {
|
||||
return "", &influxdb.Error{
|
||||
Code: influxdb.EInvalid,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
return c.Value, nil
|
||||
}
|
||||
|
||||
// SetCookieSession adds a cookie for the session to an http request
|
||||
func SetCookieSession(key string, r *http.Request) {
|
||||
c := &http.Cookie{
|
||||
Name: cookieSessionName,
|
||||
Value: key,
|
||||
Secure: true,
|
||||
}
|
||||
|
||||
r.AddCookie(c)
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/influxdb/v2"
|
||||
"github.com/influxdata/influxdb/v2/mock"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
func TestSessionHandler_handleSignin(t *testing.T) {
|
||||
type fields struct {
|
||||
PasswordsService influxdb.PasswordsService
|
||||
SessionService influxdb.SessionService
|
||||
}
|
||||
type args struct {
|
||||
user string
|
||||
password string
|
||||
}
|
||||
type wants struct {
|
||||
cookie string
|
||||
code int
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wants wants
|
||||
}{
|
||||
{
|
||||
name: "successful compare password",
|
||||
fields: fields{
|
||||
SessionService: &mock.SessionService{
|
||||
CreateSessionFn: func(context.Context, string) (*influxdb.Session, error) {
|
||||
return &influxdb.Session{
|
||||
ID: influxdb.ID(0),
|
||||
Key: "abc123xyz",
|
||||
CreatedAt: time.Date(2018, 9, 26, 0, 0, 0, 0, time.UTC),
|
||||
ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC),
|
||||
UserID: influxdb.ID(1),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
PasswordsService: &mock.PasswordsService{
|
||||
ComparePasswordFn: func(context.Context, influxdb.ID, string) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
user: "user1",
|
||||
password: "supersecret",
|
||||
},
|
||||
wants: wants{
|
||||
cookie: "session=abc123xyz",
|
||||
code: http.StatusNoContent,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userSVC := mock.NewUserService()
|
||||
userSVC.FindUserFn = func(_ context.Context, f influxdb.UserFilter) (*influxdb.User, error) {
|
||||
return &influxdb.User{ID: 1}, nil
|
||||
}
|
||||
h := NewSessionHandler(zaptest.NewLogger(t), tt.fields.SessionService, userSVC, tt.fields.PasswordsService)
|
||||
|
||||
server := httptest.NewServer(h.SignInResourceHandler())
|
||||
client := server.Client()
|
||||
|
||||
r, err := http.NewRequest("POST", server.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r.SetBasicAuth(tt.args.user, tt.args.password)
|
||||
|
||||
resp, err := client.Do(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want := resp.StatusCode, tt.wants.code; got != want {
|
||||
t.Errorf("bad status code: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
cookie := resp.Header.Get("Set-Cookie")
|
||||
if got, want := cookie, tt.wants.cookie; got != want {
|
||||
t.Errorf("expected session cookie to be set: got %q want %q", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/influxdb/v2"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SessionLogger is a logger service middleware for sessions
|
||||
type SessionLogger struct {
|
||||
logger *zap.Logger
|
||||
sessionService influxdb.SessionService
|
||||
}
|
||||
|
||||
var _ influxdb.SessionService = (*SessionLogger)(nil)
|
||||
|
||||
// NewSessionLogger returns a logging service middleware for the User Service.
|
||||
func NewSessionLogger(log *zap.Logger, s influxdb.SessionService) *SessionLogger {
|
||||
return &SessionLogger{
|
||||
logger: log,
|
||||
sessionService: s,
|
||||
}
|
||||
}
|
||||
|
||||
// FindSession calls the underlying session service and logs the results of the request
|
||||
func (l *SessionLogger) FindSession(ctx context.Context, key string) (session *influxdb.Session, err error) {
|
||||
defer func(start time.Time) {
|
||||
dur := zap.Duration("took", time.Since(start))
|
||||
if err != nil {
|
||||
l.logger.Error("failed to session find", zap.Error(err), dur)
|
||||
return
|
||||
}
|
||||
l.logger.Debug("session find", dur)
|
||||
}(time.Now())
|
||||
return l.sessionService.FindSession(ctx, key)
|
||||
|
||||
}
|
||||
|
||||
// ExpireSession calls the underlying session service and logs the results of the request
|
||||
func (l *SessionLogger) ExpireSession(ctx context.Context, key string) (err error) {
|
||||
defer func(start time.Time) {
|
||||
dur := zap.Duration("took", time.Since(start))
|
||||
if err != nil {
|
||||
l.logger.Error("failed to session expire", zap.Error(err), dur)
|
||||
return
|
||||
}
|
||||
l.logger.Debug("session expire", dur)
|
||||
}(time.Now())
|
||||
return l.sessionService.ExpireSession(ctx, key)
|
||||
|
||||
}
|
||||
|
||||
// CreateSession calls the underlying session service and logs the results of the request
|
||||
func (l *SessionLogger) CreateSession(ctx context.Context, user string) (s *influxdb.Session, err error) {
|
||||
defer func(start time.Time) {
|
||||
dur := zap.Duration("took", time.Since(start))
|
||||
if err != nil {
|
||||
l.logger.Error("failed to session create", zap.Error(err), dur)
|
||||
return
|
||||
}
|
||||
l.logger.Debug("session create", dur)
|
||||
}(time.Now())
|
||||
return l.sessionService.CreateSession(ctx, user)
|
||||
|
||||
}
|
||||
|
||||
// RenewSession calls the underlying session service and logs the results of the request
|
||||
func (l *SessionLogger) RenewSession(ctx context.Context, session *influxdb.Session, newExpiration time.Time) (err error) {
|
||||
defer func(start time.Time) {
|
||||
dur := zap.Duration("took", time.Since(start))
|
||||
if err != nil {
|
||||
l.logger.Error("failed to session renew", zap.Error(err), dur)
|
||||
return
|
||||
}
|
||||
l.logger.Debug("session renew", dur)
|
||||
}(time.Now())
|
||||
return l.sessionService.RenewSession(ctx, session, newExpiration)
|
||||
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/influxdb/v2"
|
||||
"github.com/influxdata/influxdb/v2/kit/metric"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// SessionMetrics is a metrics middleware system for the session service
|
||||
type SessionMetrics struct {
|
||||
// RED metrics
|
||||
rec *metric.REDClient
|
||||
|
||||
sessionSvc influxdb.SessionService
|
||||
}
|
||||
|
||||
var _ influxdb.SessionService = (*SessionMetrics)(nil)
|
||||
|
||||
// NewSessionMetrics creates a new session metrics middleware
|
||||
func NewSessionMetrics(reg prometheus.Registerer, s influxdb.SessionService) *SessionMetrics {
|
||||
return &SessionMetrics{
|
||||
rec: metric.New(reg, "session"),
|
||||
sessionSvc: s,
|
||||
}
|
||||
}
|
||||
|
||||
// FindSession calls the underlying session service and tracks RED metrics for the call
|
||||
func (m *SessionMetrics) FindSession(ctx context.Context, key string) (session *influxdb.Session, err error) {
|
||||
rec := m.rec.Record("find_session")
|
||||
session, err = m.sessionSvc.FindSession(ctx, key)
|
||||
return session, rec(err)
|
||||
}
|
||||
|
||||
// ExpireSession calls the underlying session service and tracks RED metrics for the call
|
||||
func (m *SessionMetrics) ExpireSession(ctx context.Context, key string) (err error) {
|
||||
rec := m.rec.Record("expire_session")
|
||||
err = m.sessionSvc.ExpireSession(ctx, key)
|
||||
return rec(err)
|
||||
}
|
||||
|
||||
// CreateSession calls the underlying session service and tracks RED metrics for the call
|
||||
func (m *SessionMetrics) CreateSession(ctx context.Context, user string) (s *influxdb.Session, err error) {
|
||||
rec := m.rec.Record("create_session")
|
||||
s, err = m.sessionSvc.CreateSession(ctx, user)
|
||||
return s, rec(err)
|
||||
}
|
||||
|
||||
// RenewSession calls the underlying session service and tracks RED metrics for the call
|
||||
func (m *SessionMetrics) RenewSession(ctx context.Context, session *influxdb.Session, newExpiration time.Time) (err error) {
|
||||
rec := m.rec.Record("renew_session")
|
||||
err = m.sessionSvc.RenewSession(ctx, session, newExpiration)
|
||||
return rec(err)
|
||||
}
|
|
@ -0,0 +1,175 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/influxdb/v2"
|
||||
"github.com/influxdata/influxdb/v2/rand"
|
||||
"github.com/influxdata/influxdb/v2/snowflake"
|
||||
)
|
||||
|
||||
// Service implements the influxdb.SessionService interface and
|
||||
// handles communication between session and the necessary user and urm services
|
||||
type Service struct {
|
||||
store *Storage
|
||||
userService influxdb.UserService
|
||||
urmService influxdb.UserResourceMappingService
|
||||
authService influxdb.AuthorizationService
|
||||
sessionLength time.Duration
|
||||
|
||||
idGen influxdb.IDGenerator
|
||||
tokenGen influxdb.TokenGenerator
|
||||
|
||||
disableAuthorizationsForMaxPermissions func(context.Context) bool
|
||||
}
|
||||
|
||||
// NewService creates a new session service
|
||||
func NewService(store *Storage, userService influxdb.UserService, urmService influxdb.UserResourceMappingService, authSvc influxdb.AuthorizationService, sessionLength time.Duration) *Service {
|
||||
if sessionLength <= 0 {
|
||||
sessionLength = time.Hour
|
||||
}
|
||||
return &Service{
|
||||
store: store,
|
||||
userService: userService,
|
||||
urmService: urmService,
|
||||
authService: authSvc,
|
||||
sessionLength: sessionLength,
|
||||
idGen: snowflake.NewIDGenerator(),
|
||||
tokenGen: rand.NewTokenGenerator(64),
|
||||
disableAuthorizationsForMaxPermissions: func(context.Context) bool {
|
||||
return false
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxPermissionFunc sets the useAuthorizationsForMaxPermissions function
|
||||
// which can trigger whether or not max permissions uses the users authorizations
|
||||
// to derive maximum permissions.
|
||||
func (s *Service) WithMaxPermissionFunc(fn func(context.Context) bool) {
|
||||
s.disableAuthorizationsForMaxPermissions = fn
|
||||
}
|
||||
|
||||
// FindSession finds a session based on the session key
|
||||
func (s *Service) FindSession(ctx context.Context, key string) (*influxdb.Session, error) {
|
||||
session, err := s.store.FindSessionByKey(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: We want to be able to store permissions in the session
|
||||
// but the contract provided by urm's doesn't give us enough information to quickly repopulate our
|
||||
// session permissions on updates so we are required to pull the permissions every time we find the session.
|
||||
permissions, err := s.getPermissionSet(ctx, session.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session.Permissions = permissions
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// ExpireSession removes a session from the system
|
||||
func (s *Service) ExpireSession(ctx context.Context, key string) error {
|
||||
session, err := s.store.FindSessionByKey(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.store.DeleteSession(ctx, session.ID)
|
||||
}
|
||||
|
||||
// CreateSession
|
||||
func (s *Service) CreateSession(ctx context.Context, user string) (*influxdb.Session, error) {
|
||||
u, err := s.userService.FindUser(ctx, influxdb.UserFilter{
|
||||
Name: &user,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token, err := s.tokenGen.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// for now we are not storing the permissions because we need to pull them every time we find
|
||||
// so we might as well keep the session stored small
|
||||
now := time.Now()
|
||||
session := &influxdb.Session{
|
||||
ID: s.idGen.ID(),
|
||||
Key: token,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(s.sessionLength),
|
||||
UserID: u.ID,
|
||||
}
|
||||
|
||||
return session, s.store.CreateSession(ctx, session)
|
||||
}
|
||||
|
||||
// RenewSession update the sessions expiration time
|
||||
func (s *Service) RenewSession(ctx context.Context, session *influxdb.Session, newExpiration time.Time) error {
|
||||
if session == nil {
|
||||
return &influxdb.Error{
|
||||
Msg: "session is nil",
|
||||
}
|
||||
}
|
||||
return s.store.RefreshSession(ctx, session.ID, newExpiration)
|
||||
}
|
||||
|
||||
func (s *Service) getPermissionSet(ctx context.Context, uid influxdb.ID) ([]influxdb.Permission, error) {
|
||||
|
||||
mappings, _, err := s.urmService.FindUserResourceMappings(ctx, influxdb.UserResourceMappingFilter{UserID: uid}, influxdb.FindOptions{Limit: 100})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
permissions, err := permissionFromMapping(mappings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(mappings) == 100 {
|
||||
// if we got 100 mappings we probably need to pull more pages
|
||||
// account for paginated results
|
||||
for i := len(mappings); len(mappings) > 0; i += len(mappings) {
|
||||
mappings, _, err = s.urmService.FindUserResourceMappings(ctx, influxdb.UserResourceMappingFilter{UserID: uid}, influxdb.FindOptions{Offset: i, Limit: 100})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pms, err := permissionFromMapping(mappings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
permissions = append(permissions, pms...)
|
||||
}
|
||||
}
|
||||
|
||||
if !s.disableAuthorizationsForMaxPermissions(ctx) {
|
||||
as, _, err := s.authService.FindAuthorizations(ctx, influxdb.AuthorizationFilter{UserID: &uid})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, a := range as {
|
||||
permissions = append(permissions, a.Permissions...)
|
||||
}
|
||||
}
|
||||
|
||||
permissions = append(permissions, influxdb.MePermissions(uid)...)
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
func permissionFromMapping(mappings []*influxdb.UserResourceMapping) ([]influxdb.Permission, error) {
|
||||
ps := make([]influxdb.Permission, 0, len(mappings))
|
||||
for _, m := range mappings {
|
||||
p, err := m.ToPermissions()
|
||||
if err != nil {
|
||||
return nil, &influxdb.Error{
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
ps = append(ps, p...)
|
||||
}
|
||||
|
||||
return ps, nil
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/influxdb/v2"
|
||||
"github.com/influxdata/influxdb/v2/inmem"
|
||||
"github.com/influxdata/influxdb/v2/mock"
|
||||
"github.com/influxdata/influxdb/v2/tenant"
|
||||
influxdbtesting "github.com/influxdata/influxdb/v2/testing"
|
||||
)
|
||||
|
||||
func TestSessionService(t *testing.T) {
|
||||
influxdbtesting.SessionService(initSessionService, t)
|
||||
}
|
||||
|
||||
func initSessionService(f influxdbtesting.SessionFields, t *testing.T) (influxdb.SessionService, string, func()) {
|
||||
ss := NewStorage(inmem.NewSessionStore())
|
||||
ts, _ := tenant.NewStore(inmem.NewKVStore())
|
||||
ten := tenant.NewService(ts)
|
||||
svc := NewService(ss, ten, ten, &mock.AuthorizationService{
|
||||
FindAuthorizationsFn: func(context.Context, influxdb.AuthorizationFilter, ...influxdb.FindOptions) ([]*influxdb.Authorization, int, error) {
|
||||
return []*influxdb.Authorization{}, 0, nil
|
||||
},
|
||||
}, time.Minute)
|
||||
|
||||
if f.IDGenerator != nil {
|
||||
svc.idGen = f.IDGenerator
|
||||
}
|
||||
|
||||
if f.TokenGenerator != nil {
|
||||
svc.tokenGen = f.TokenGenerator
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, u := range f.Users {
|
||||
if err := ten.CreateUser(ctx, u); err != nil {
|
||||
t.Fatalf("failed to populate users")
|
||||
}
|
||||
}
|
||||
for _, s := range f.Sessions {
|
||||
if err := ss.CreateSession(ctx, s); err != nil {
|
||||
t.Fatalf("failed to populate sessions")
|
||||
}
|
||||
}
|
||||
return svc, "session", func() {}
|
||||
}
|
|
@ -0,0 +1,136 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/influxdb/v2"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
Set(key, val string, expireAt time.Time) error
|
||||
Get(key string) (string, error)
|
||||
Delete(key string) error
|
||||
ExpireAt(key string, expireAt time.Time) error
|
||||
}
|
||||
|
||||
var storePrefix = "sessionsv1/"
|
||||
var storeIndex = "sessionsindexv1/"
|
||||
|
||||
// Storage is a store translation layer between the data storage unit and the
|
||||
// service layer.
|
||||
type Storage struct {
|
||||
store Store
|
||||
}
|
||||
|
||||
// NewStorage creates a new storage system
|
||||
func NewStorage(s Store) *Storage {
|
||||
return &Storage{s}
|
||||
}
|
||||
|
||||
// FindSessionByKey use a given key to retrieve the stored session
|
||||
func (s *Storage) FindSessionByKey(ctx context.Context, key string) (*influxdb.Session, error) {
|
||||
val, err := s.store.Get(sessionIndexKey(key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if val == "" {
|
||||
return nil, &influxdb.Error{
|
||||
Code: influxdb.ENotFound,
|
||||
Msg: influxdb.ErrSessionNotFound,
|
||||
}
|
||||
}
|
||||
|
||||
id, err := influxdb.IDFromString(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.FindSessionByID(ctx, *id)
|
||||
}
|
||||
|
||||
// FindSessionByID use a provided id to retrieve the stored session
|
||||
func (s *Storage) FindSessionByID(ctx context.Context, id influxdb.ID) (*influxdb.Session, error) {
|
||||
val, err := s.store.Get(storePrefix + id.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if val == "" {
|
||||
return nil, &influxdb.Error{
|
||||
Code: influxdb.ENotFound,
|
||||
Msg: influxdb.ErrSessionNotFound,
|
||||
}
|
||||
}
|
||||
|
||||
session := &influxdb.Session{}
|
||||
return session, json.Unmarshal([]byte(val), session)
|
||||
}
|
||||
|
||||
// CreateSession creates a new session
|
||||
func (s *Storage) CreateSession(ctx context.Context, session *influxdb.Session) error {
|
||||
// create session
|
||||
sessionBytes, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// use a minute time just so the session will expire if we fail to set the expiration later
|
||||
sessionID := sessionID(session.ID)
|
||||
if err := s.store.Set(sessionID, string(sessionBytes), session.ExpiresAt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// create index
|
||||
indexKey := sessionIndexKey(session.Key)
|
||||
if err := s.store.Set(indexKey, session.ID.String(), session.ExpiresAt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshSession updates the expiration time of a session.
|
||||
func (s *Storage) RefreshSession(ctx context.Context, id influxdb.ID, expireAt time.Time) error {
|
||||
session, err := s.FindSessionByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if expireAt.Before(session.ExpiresAt) {
|
||||
// no need to recreate the session if we aren't extending the expiration
|
||||
return nil
|
||||
}
|
||||
|
||||
session.ExpiresAt = expireAt
|
||||
return s.CreateSession(ctx, session)
|
||||
}
|
||||
|
||||
// DeleteSession removes the session and index from storage
|
||||
func (s *Storage) DeleteSession(ctx context.Context, id influxdb.ID) error {
|
||||
session, err := s.FindSessionByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if session == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.store.Delete(sessionID(session.ID)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.store.Delete(sessionIndexKey(session.Key)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func sessionID(id influxdb.ID) string {
|
||||
return storePrefix + id.String()
|
||||
}
|
||||
|
||||
func sessionIndexKey(key string) string {
|
||||
return storeIndex + key
|
||||
}
|
|
@ -0,0 +1,120 @@
|
|||
package session_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/influxdata/influxdb/v2"
|
||||
"github.com/influxdata/influxdb/v2/inmem"
|
||||
"github.com/influxdata/influxdb/v2/session"
|
||||
)
|
||||
|
||||
func TestSessionStore(t *testing.T) {
|
||||
driver := func() session.Store {
|
||||
return inmem.NewSessionStore()
|
||||
}
|
||||
|
||||
expected := &influxdb.Session{
|
||||
ID: 1,
|
||||
Key: "2",
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
}
|
||||
|
||||
simpleSetup := func(t *testing.T, store *session.Storage) {
|
||||
err := store.CreateSession(
|
||||
context.Background(),
|
||||
expected,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
st := []struct {
|
||||
name string
|
||||
setup func(*testing.T, *session.Storage)
|
||||
update func(*testing.T, *session.Storage)
|
||||
results func(*testing.T, *session.Storage)
|
||||
}{
|
||||
{
|
||||
name: "create",
|
||||
setup: simpleSetup,
|
||||
results: func(t *testing.T, store *session.Storage) {
|
||||
session, err := store.FindSessionByID(context.Background(), 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !cmp.Equal(session, expected) {
|
||||
t.Fatalf("expected identical sessions: \n%+v\n%+v", session, expected)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get",
|
||||
setup: simpleSetup,
|
||||
results: func(t *testing.T, store *session.Storage) {
|
||||
session, err := store.FindSessionByID(context.Background(), 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !cmp.Equal(session, expected) {
|
||||
t.Fatalf("expected identical sessions: \n%+v\n%+v", session, expected)
|
||||
}
|
||||
|
||||
session, err = store.FindSessionByKey(context.Background(), "2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !cmp.Equal(session, expected) {
|
||||
t.Fatalf("expected identical sessions: \n%+v\n%+v", session, expected)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete",
|
||||
setup: simpleSetup,
|
||||
update: func(t *testing.T, store *session.Storage) {
|
||||
err := store.DeleteSession(context.Background(), 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
results: func(t *testing.T, store *session.Storage) {
|
||||
session, err := store.FindSessionByID(context.Background(), 1)
|
||||
if err == nil {
|
||||
t.Fatal("expected error on deleted session but got none")
|
||||
}
|
||||
|
||||
if session != nil {
|
||||
t.Fatal("got a session when none should have existed")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, testScenario := range st {
|
||||
t.Run(testScenario.name, func(t *testing.T) {
|
||||
ss := session.NewStorage(driver())
|
||||
|
||||
// setup
|
||||
if testScenario.setup != nil {
|
||||
testScenario.setup(t, ss)
|
||||
}
|
||||
|
||||
// update
|
||||
if testScenario.update != nil {
|
||||
testScenario.update(t, ss)
|
||||
}
|
||||
|
||||
// results
|
||||
if testScenario.results != nil {
|
||||
testScenario.results(t, ss)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -326,7 +326,7 @@ func (as *AnalyticalStorage) RetryRun(ctx context.Context, taskID, runID influxd
|
|||
return run, err
|
||||
}
|
||||
|
||||
// try finding the run (in our system or underlieing)
|
||||
// try finding the run (in our system or underlying)
|
||||
run, err = as.FindRunByID(ctx, taskID, runID)
|
||||
if err != nil {
|
||||
return run, err
|
||||
|
|
|
@ -24,12 +24,12 @@ func NewAuthedOnboardSvc(s influxdb.OnboardingService) *AuthedOnboardSvc {
|
|||
}
|
||||
}
|
||||
|
||||
// IsOnboarding pass through. this is handled by the underlieing service layer
|
||||
// IsOnboarding pass through. this is handled by the underlying service layer
|
||||
func (s *AuthedOnboardSvc) IsOnboarding(ctx context.Context) (bool, error) {
|
||||
return s.s.IsOnboarding(ctx)
|
||||
}
|
||||
|
||||
// OnboardInitialUser pass through. this is handled by the underlieing service layer
|
||||
// OnboardInitialUser pass through. this is handled by the underlying service layer
|
||||
func (s *AuthedOnboardSvc) OnboardInitialUser(ctx context.Context, req *influxdb.OnboardingRequest) (*influxdb.OnboardingResults, error) {
|
||||
return s.s.OnboardInitialUser(ctx, req)
|
||||
}
|
||||
|
|
|
@ -68,6 +68,11 @@ func (s *Store) setup() error {
|
|||
if _, err := tx.Bucket(urmBucket); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tx.Bucket(urmByUserIndexBucket); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tx.Bucket(organizationBucket); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/influxdata/influxdb/v2"
|
||||
platform "github.com/influxdata/influxdb/v2"
|
||||
"github.com/influxdata/influxdb/v2/mock"
|
||||
)
|
||||
|
@ -274,12 +275,12 @@ func ExpireSession(
|
|||
diffPlatformErrors(tt.name, err, tt.wants.err, opPrefix, t)
|
||||
|
||||
session, err := s.FindSession(ctx, tt.args.key)
|
||||
if err == nil {
|
||||
if err.Error() != influxdb.ErrSessionExpired && err.Error() != influxdb.ErrSessionNotFound {
|
||||
t.Errorf("expected session to be expired got %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(session, tt.wants.session, sessionCmpOptions...); diff != "" {
|
||||
t.Errorf("session is different -got/+want\ndiff %s", diff)
|
||||
if session != nil {
|
||||
t.Errorf("expected a nil session but got: %v", session)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue