From 15a32a086043d0c34e0502708ab806b14b73165a Mon Sep 17 00:00:00 2001 From: Daniel Moran Date: Wed, 6 Oct 2021 22:37:02 -0400 Subject: [PATCH] refactor: consolidate session-handling code (#22626) --- http/authentication_middleware.go | 5 +- http/authentication_test.go | 5 +- http/middleware.go | 4 +- http/session_handler.go | 187 ------------------------------ http/session_test.go | 105 ----------------- session/http_server.go | 11 +- 6 files changed, 14 insertions(+), 303 deletions(-) delete mode 100644 http/session_handler.go delete mode 100644 http/session_test.go diff --git a/http/authentication_middleware.go b/http/authentication_middleware.go index f53f2d3006..3b9bb4be52 100644 --- a/http/authentication_middleware.go +++ b/http/authentication_middleware.go @@ -12,6 +12,7 @@ import ( platcontext "github.com/influxdata/influxdb/v2/context" "github.com/influxdata/influxdb/v2/jsonweb" errors2 "github.com/influxdata/influxdb/v2/kit/platform/errors" + "github.com/influxdata/influxdb/v2/session" "github.com/opentracing/opentracing-go" "go.uber.org/zap" ) @@ -59,7 +60,7 @@ const ( // ProbeAuthScheme probes the http request for the requests for token or cookie session. func ProbeAuthScheme(r *http.Request) (string, error) { _, tokenErr := GetToken(r) - _, sessErr := decodeCookieSession(r.Context(), r) + _, sessErr := session.DecodeCookieSession(r.Context(), r) if tokenErr != nil && sessErr != nil { return "", fmt.Errorf("token required") @@ -162,7 +163,7 @@ func (h *AuthenticationHandler) extractAuthorization(ctx context.Context, r *htt } func (h *AuthenticationHandler) extractSession(ctx context.Context, r *http.Request) (*platform.Session, error) { - k, err := decodeCookieSession(ctx, r) + k, err := session.DecodeCookieSession(ctx, r) if err != nil { return nil, err } diff --git a/http/authentication_test.go b/http/authentication_test.go index 84247f027c..f4f853c1be 100644 --- a/http/authentication_test.go +++ b/http/authentication_test.go @@ -15,6 +15,7 @@ import ( "github.com/influxdata/influxdb/v2/kit/platform" kithttp "github.com/influxdata/influxdb/v2/kit/transport/http" "github.com/influxdata/influxdb/v2/mock" + "github.com/influxdata/influxdb/v2/session" "go.uber.org/zap/zaptest" ) @@ -233,7 +234,7 @@ func TestAuthenticationHandler(t *testing.T) { r := httptest.NewRequest("POST", "http://any.url", nil) if tt.args.session != "" { - platformhttp.SetCookieSession(tt.args.session, r) + session.SetCookieSession(tt.args.session, r) } if tt.args.token != "" { @@ -296,7 +297,7 @@ func TestProbeAuthScheme(t *testing.T) { r := httptest.NewRequest("POST", "http://any.url", nil) if tt.args.session != "" { - platformhttp.SetCookieSession(tt.args.session, r) + session.SetCookieSession(tt.args.session, r) } if tt.args.token != "" { diff --git a/http/middleware.go b/http/middleware.go index f121c84905..941b269ea4 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -133,8 +133,8 @@ const ( // TODO(@jsteenb2): make this a stronger type that handlers can register routes that should not be logged. var blacklistEndpoints = map[string]isValidMethodFn{ - prefixSignIn: ignoreMethod(), - prefixSignOut: ignoreMethod(), + "/api/v2/signin": ignoreMethod(), + "/api/v2/signout": ignoreMethod(), prefixMe: ignoreMethod(), mePasswordPath: ignoreMethod(), usersPasswordPath: ignoreMethod(), diff --git a/http/session_handler.go b/http/session_handler.go deleted file mode 100644 index 582c7b6fe3..0000000000 --- a/http/session_handler.go +++ /dev/null @@ -1,187 +0,0 @@ -package http - -import ( - "context" - "net/http" - - "github.com/influxdata/httprouter" - platform "github.com/influxdata/influxdb/v2" - "github.com/influxdata/influxdb/v2/kit/platform/errors" - "go.uber.org/zap" -) - -const ( - prefixSignIn = "/api/v2/signin" - prefixSignOut = "/api/v2/signout" -) - -// SessionBackend is all services and associated parameters required to construct -// the SessionHandler. -type SessionBackend struct { - log *zap.Logger - errors.HTTPErrorHandler - - PasswordsService platform.PasswordsService - SessionService platform.SessionService - UserService platform.UserService -} - -// NewSessionBackend creates a new SessionBackend with associated logger. -func NewSessionBackend(log *zap.Logger, b *APIBackend) *SessionBackend { - return &SessionBackend{ - HTTPErrorHandler: b.HTTPErrorHandler, - log: log, - - PasswordsService: b.PasswordsService, - SessionService: b.SessionService, - UserService: b.UserService, - } -} - -// SessionHandler represents an HTTP API handler for authorizations. -type SessionHandler struct { - *httprouter.Router - errors.HTTPErrorHandler - log *zap.Logger - - PasswordsService platform.PasswordsService - SessionService platform.SessionService - UserService platform.UserService -} - -// NewSessionHandler returns a new instance of SessionHandler. -func NewSessionHandler(log *zap.Logger, b *SessionBackend) *SessionHandler { - h := &SessionHandler{ - Router: NewRouter(b.HTTPErrorHandler), - HTTPErrorHandler: b.HTTPErrorHandler, - log: log, - - PasswordsService: b.PasswordsService, - SessionService: b.SessionService, - UserService: b.UserService, - } - - h.HandlerFunc("POST", prefixSignIn, h.handleSignin) - h.HandlerFunc("POST", prefixSignOut, h.handleSignout) - return h -} - -// handleSignin is the HTTP handler for the POST /signin route. -func (h *SessionHandler) handleSignin(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - req, decErr := decodeSigninRequest(ctx, r) - if decErr != nil { - UnauthorizedError(ctx, h, w) - return - } - - u, err := h.UserService.FindUser(ctx, platform.UserFilter{ - Name: &req.Username, - }) - if err != nil { - UnauthorizedError(ctx, h, w) - return - } - - if err := h.PasswordsService.ComparePassword(ctx, u.ID, req.Password); err != nil { - // Don't log here, it should already be handled by the service - UnauthorizedError(ctx, h, w) - return - } - - s, e := h.SessionService.CreateSession(ctx, req.Username) - if e != nil { - UnauthorizedError(ctx, h, w) - return - } - - encodeCookieSession(w, s) - w.WriteHeader(http.StatusNoContent) -} - -type signinRequest struct { - Username string - Password string -} - -func decodeSigninRequest(ctx context.Context, r *http.Request) (*signinRequest, *errors.Error) { - u, p, ok := r.BasicAuth() - if !ok { - return nil, &errors.Error{ - Code: errors.EInvalid, - Msg: "invalid basic auth", - } - } - - return &signinRequest{ - Username: u, - Password: p, - }, nil -} - -// handleSignout is the HTTP handler for the POST /signout route. -func (h *SessionHandler) handleSignout(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - req, err := decodeSignoutRequest(ctx, r) - if err != nil { - UnauthorizedError(ctx, h, w) - return - } - - if err := h.SessionService.ExpireSession(ctx, req.Key); err != nil { - UnauthorizedError(ctx, h, w) - return - } - - // TODO(desa): not sure what to do here maybe redirect? - w.WriteHeader(http.StatusNoContent) -} - -type signoutRequest struct { - Key string -} - -func decodeSignoutRequest(ctx context.Context, r *http.Request) (*signoutRequest, error) { - key, err := decodeCookieSession(ctx, r) - if err != nil { - return nil, err - } - return &signoutRequest{ - Key: key, - }, nil -} - -const cookieSessionName = "session" - -func encodeCookieSession(w http.ResponseWriter, s *platform.Session) { - c := &http.Cookie{ - Name: cookieSessionName, - Value: s.Key, - } - - http.SetCookie(w, c) -} -func decodeCookieSession(ctx context.Context, r *http.Request) (string, error) { - c, err := r.Cookie(cookieSessionName) - if err != nil { - return "", &errors.Error{ - Code: errors.EInvalid, - Err: err, - } - } - return c.Value, nil -} - -// SetCookieSession adds a cookie for the session to an http request -func SetCookieSession(key string, r *http.Request) { - c := &http.Cookie{ - Name: cookieSessionName, - Value: key, - Secure: true, - SameSite: http.SameSiteStrictMode, - } - - r.AddCookie(c) -} diff --git a/http/session_test.go b/http/session_test.go deleted file mode 100644 index 85f4556a65..0000000000 --- a/http/session_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package http - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - "time" - - platform "github.com/influxdata/influxdb/v2" - platform2 "github.com/influxdata/influxdb/v2/kit/platform" - "github.com/influxdata/influxdb/v2/mock" - "go.uber.org/zap/zaptest" -) - -// NewMockSessionBackend returns a SessionBackend with mock services. -func NewMockSessionBackend(t *testing.T) *SessionBackend { - userSVC := mock.NewUserService() - userSVC.FindUserFn = func(_ context.Context, f platform.UserFilter) (*platform.User, error) { - return &platform.User{ID: 1}, nil - } - return &SessionBackend{ - log: zaptest.NewLogger(t), - - SessionService: mock.NewSessionService(), - PasswordsService: mock.NewPasswordsService(), - UserService: userSVC, - } -} - -func TestSessionHandler_handleSignin(t *testing.T) { - type fields struct { - PasswordsService platform.PasswordsService - SessionService platform.SessionService - } - type args struct { - user string - password string - } - type wants struct { - cookie string - code int - } - - tests := []struct { - name string - fields fields - args args - wants wants - }{ - { - name: "successful compare password", - fields: fields{ - SessionService: &mock.SessionService{ - CreateSessionFn: func(context.Context, string) (*platform.Session, error) { - return &platform.Session{ - ID: platform2.ID(0), - Key: "abc123xyz", - CreatedAt: time.Date(2018, 9, 26, 0, 0, 0, 0, time.UTC), - ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC), - UserID: platform2.ID(1), - }, nil - }, - }, - PasswordsService: &mock.PasswordsService{ - ComparePasswordFn: func(context.Context, platform2.ID, string) error { - return nil - }, - }, - }, - args: args{ - user: "user1", - password: "supersecret", - }, - wants: wants{ - cookie: "session=abc123xyz", - code: http.StatusNoContent, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - b := NewMockSessionBackend(t) - b.PasswordsService = tt.fields.PasswordsService - b.SessionService = tt.fields.SessionService - h := NewSessionHandler(zaptest.NewLogger(t), b) - - w := httptest.NewRecorder() - r := httptest.NewRequest("POST", "http://localhost:8086/api/v2/signin", nil) - r.SetBasicAuth(tt.args.user, tt.args.password) - h.ServeHTTP(w, r) - - if got, want := w.Code, tt.wants.code; got != want { - t.Errorf("bad status code: got %d want %d", got, want) - } - - headers := w.Header() - cookie := headers.Get("Set-Cookie") - if got, want := cookie, tt.wants.cookie; got != want { - t.Errorf("expected session cookie to be set: got %q want %q", got, want) - } - }) - } -} diff --git a/session/http_server.go b/session/http_server.go index bab678453e..02c4c167fa 100644 --- a/session/http_server.go +++ b/session/http_server.go @@ -152,7 +152,7 @@ type signoutRequest struct { } func decodeSignoutRequest(ctx context.Context, r *http.Request) (*signoutRequest, error) { - key, err := decodeCookieSession(ctx, r) + key, err := DecodeCookieSession(ctx, r) if err != nil { return nil, err } @@ -173,7 +173,7 @@ func encodeCookieSession(w http.ResponseWriter, s *influxdb.Session) { http.SetCookie(w, c) } -func decodeCookieSession(ctx context.Context, r *http.Request) (string, error) { +func DecodeCookieSession(ctx context.Context, r *http.Request) (string, error) { c, err := r.Cookie(cookieSessionName) if err != nil { return "", &errors.Error{ @@ -187,9 +187,10 @@ func decodeCookieSession(ctx context.Context, r *http.Request) (string, error) { // SetCookieSession adds a cookie for the session to an http request func SetCookieSession(key string, r *http.Request) { c := &http.Cookie{ - Name: cookieSessionName, - Value: key, - Secure: true, + Name: cookieSessionName, + Value: key, + Secure: true, + SameSite: http.SameSiteStrictMode, } r.AddCookie(c)