diff --git a/oauth2/auth.go b/oauth2/auth.go index ed307bfaa..fc132eab2 100644 --- a/oauth2/auth.go +++ b/oauth2/auth.go @@ -56,12 +56,12 @@ func AuthorizedToken(auth Authenticator, te TokenExtractor, logger chronograf.Lo token, err := te.Extract(r) if err != nil { // Happens when Provider okays authentication, but Token is bad - log.Error("Unable to extract token") + log.Info("Unauthenticated user") w.WriteHeader(http.StatusUnauthorized) return } // We do not check the validity of the principal. Those - // server further down the chain should do so. + // served further down the chain should do so. principal, err := auth.Authenticate(r.Context(), token) if err != nil { log.Error("Invalid token") diff --git a/oauth2/auth_test.go b/oauth2/auth_test.go index 8e2011c43..40d2ec36f 100644 --- a/oauth2/auth_test.go +++ b/oauth2/auth_test.go @@ -154,10 +154,12 @@ func TestAuthorizedToken(t *testing.T) { AuthErr: errors.New("error"), }, { - Desc: "Authorized ok", - Code: http.StatusOK, - Principal: "Principal Strickland", - Expected: "Principal Strickland", + Desc: "Authorized ok", + Code: http.StatusOK, + Principal: oauth2.Principal{ + Subject: "Principal Strickland", + }, + Expected: "Principal Strickland", }, } for _, test := range tests { diff --git a/oauth2/jwt.go b/oauth2/jwt.go index fd10345fe..fb54b62f9 100644 --- a/oauth2/jwt.go +++ b/oauth2/jwt.go @@ -61,17 +61,20 @@ func (j *JWT) Authenticate(ctx context.Context, jwtToken string) (Principal, err // 4. Check if subject is not empty token, err := gojwt.ParseWithClaims(jwtToken, &Claims{}, alg) if err != nil { - return "", err + return Principal{}, err } else if !token.Valid { - return "", err + return Principal{}, err } claims, ok := token.Claims.(*Claims) if !ok { - return "", fmt.Errorf("unable to convert claims to standard claims") + return Principal{}, fmt.Errorf("unable to convert claims to standard claims") } - return Principal(claims.Subject), nil + return Principal{ + Subject: claims.Subject, + Issuer: claims.Issuer, + }, nil } // Token creates a signed JWT token from user that expires at Now + duration @@ -81,7 +84,8 @@ func (j *JWT) Token(ctx context.Context, user Principal, duration time.Duration) now := j.Now().UTC() claims := &Claims{ gojwt.StandardClaims{ - Subject: string(user), + Subject: user.Subject, + Issuer: user.Issuer, ExpiresAt: now.Add(duration).Unix(), IssuedAt: now.Unix(), NotBefore: now.Unix(), diff --git a/oauth2/jwt_test.go b/oauth2/jwt_test.go index ec81541ab..d65a82800 100644 --- a/oauth2/jwt_test.go +++ b/oauth2/jwt_test.go @@ -21,35 +21,45 @@ func TestAuthenticate(t *testing.T) { Desc: "Test bad jwt token", Secret: "secret", Token: "badtoken", - User: "", - Err: errors.New("token contains an invalid number of segments"), + User: oauth2.Principal{ + Subject: "", + }, + Err: errors.New("token contains an invalid number of segments"), }, { Desc: "Test valid jwt token", Secret: "secret", Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIvY2hyb25vZ3JhZi92MS91c2Vycy8xIiwibmFtZSI6IkRvYyBCcm93biIsImlhdCI6LTQ0Njc3NDQwMCwiZXhwIjotNDQ2Nzc0NDAwLCJuYmYiOi00NDY3NzQ0MDB9._rZ4gOIei9PizHOABH6kLcJTA3jm8ls0YnDxtz1qeUI", - User: "/chronograf/v1/users/1", + User: oauth2.Principal{ + Subject: "/chronograf/v1/users/1", + }, }, { Desc: "Test expired jwt token", Secret: "secret", Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIvY2hyb25vZ3JhZi92MS91c2Vycy8xIiwibmFtZSI6IkRvYyBCcm93biIsImlhdCI6LTQ0Njc3NDQwMCwiZXhwIjotNDQ2Nzc0NDAxLCJuYmYiOi00NDY3NzQ0MDB9.vWXdm0-XQ_pW62yBpSISFFJN_yz0vqT9_INcUKTp5Q8", - User: "", - Err: errors.New("token is expired by 1s"), + User: oauth2.Principal{ + Subject: "", + }, + Err: errors.New("token is expired by 1s"), }, { Desc: "Test jwt token not before time", Secret: "secret", Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIvY2hyb25vZ3JhZi92MS91c2Vycy8xIiwibmFtZSI6IkRvYyBCcm93biIsImlhdCI6LTQ0Njc3NDQwMCwiZXhwIjotNDQ2Nzc0NDAwLCJuYmYiOi00NDY3NzQzOTl9.TMGAhv57u1aosjc4ywKC7cElP1tKyQH7GmRF2ToAxlE", - User: "", - Err: errors.New("token is not valid yet"), + User: oauth2.Principal{ + Subject: "", + }, + Err: errors.New("token is not valid yet"), }, { Desc: "Test jwt with empty subject is invalid", Secret: "secret", Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOi00NDY3NzQ0MDAsImV4cCI6LTQ0Njc3NDQwMCwibmJmIjotNDQ2Nzc0NDAwfQ.gxsA6_Ei3s0f2I1TAtrrb8FmGiO25OqVlktlF_ylhX4", - User: "", - Err: errors.New("claim has no subject"), + User: oauth2.Principal{ + Subject: "", + }, + Err: errors.New("claim has no subject"), }, } for i, test := range tests { @@ -82,7 +92,10 @@ func TestToken(t *testing.T) { return time.Unix(-446774400, 0) }, } - if token, err := j.Token(context.Background(), oauth2.Principal("/chronograf/v1/users/1"), duration); err != nil { + p := oauth2.Principal{ + Subject: "/chronograf/v1/users/1", + } + if token, err := j.Token(context.Background(), p, duration); err != nil { t.Errorf("Error creating token for user: %v", err) } else if token != expected { t.Errorf("Error creating token; expected: %s actual: %s", "", token) diff --git a/oauth2/mux.go b/oauth2/mux.go index c4ae963eb..68e894a29 100644 --- a/oauth2/mux.go +++ b/oauth2/mux.go @@ -64,7 +64,10 @@ func (j *JWTMux) Login() http.Handler { // We'll give our users 10 minutes from this point to type in their github password. // If the callback is not received within 10 minutes, then authorization will fail. csrf := randomString(32) // 32 is not important... just long - state, err := j.Auth.Token(r.Context(), Principal(csrf), 10*time.Minute) + p := Principal{ + Subject: csrf, + } + state, err := j.Auth.Token(r.Context(), p, 10*time.Minute) // This is likely an internal server error if err != nil { j.Logger. @@ -122,8 +125,12 @@ func (j *JWTMux) Callback() http.Handler { return } + p := Principal{ + Subject: id, + Issuer: j.Provider.Name(), + } // We create an auth token that will be used by all other endpoints to validate the principal has a claim - authToken, err := j.Auth.Token(r.Context(), Principal(id), j.cookie.Duration) + authToken, err := j.Auth.Token(r.Context(), p, j.cookie.Duration) if err != nil { log.Error("Unable to create cookie auth token ", err.Error()) http.Redirect(w, r, j.FailureURL, http.StatusTemporaryRedirect) diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index d29046d49..6cf73f9ae 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -14,7 +14,7 @@ const ( // PrincipalKey is used to pass principal // via context.Context to request-scoped // functions. - PrincipalKey Principal = "principal" + PrincipalKey string = "principal" ) var ( @@ -25,7 +25,10 @@ var ( /* Types */ // Principal is any entity that can be authenticated -type Principal string +type Principal struct { + Subject string + Issuer string +} /* Interfaces */ diff --git a/server/logout.go b/server/logout.go new file mode 100644 index 000000000..3e827c2af --- /dev/null +++ b/server/logout.go @@ -0,0 +1,21 @@ +package server + +import "net/http" + +// Logout chooses the correct provider logout route and redirects to it +func Logout(nextURL string, routes AuthRoutes) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + principal, err := getPrincipal(ctx) + if err != nil { + http.Redirect(w, r, nextURL, http.StatusTemporaryRedirect) + return + } + route, ok := routes.Lookup(principal.Issuer) + if !ok { + http.Redirect(w, r, nextURL, http.StatusTemporaryRedirect) + return + } + http.Redirect(w, r, route.Logout, http.StatusTemporaryRedirect) + } +} diff --git a/server/mux.go b/server/mux.go index 0a94f36b5..abbc51c76 100644 --- a/server/mux.go +++ b/server/mux.go @@ -42,8 +42,8 @@ func (m *MuxOpts) UseGoogle() bool { return m.TokenSecret != "" && m.GoogleClientID != "" && m.GoogleClientSecret != "" && m.PublicURL != "" } -func (m *MuxOpts) Routes() []AuthRoute { - routes := []AuthRoute{} +func (m *MuxOpts) Routes() AuthRoutes { + routes := AuthRoutes{} if m.UseGithub() { routes = append(routes, NewGithubRoute()) } @@ -148,6 +148,11 @@ func NewMux(opts MuxOpts, service Service) http.Handler { /* Authentication */ if opts.UseAuth { + // Create middleware to redirect to the appropriate provider logout + targetURL := "/" + router.GET("/oauth/logout", Logout(targetURL, authRoutes)) + + // Encapsulate the router with OAuth2 auth := AuthAPI(opts, router) return Logger(opts.Logger, auth) } @@ -157,6 +162,7 @@ func NewMux(opts MuxOpts, service Service) http.Handler { } // AuthAPI adds the OAuth routes if auth is enabled. +// TODO: this function is not great. Would be good if providers added their routes. func AuthAPI(opts MuxOpts, router *httprouter.Router) http.Handler { auth := oauth2.NewJWT(opts.TokenSecret) if opts.UseGithub() { @@ -192,7 +198,7 @@ func AuthAPI(opts MuxOpts, router *httprouter.Router) http.Handler { tokenMiddleware := oauth2.AuthorizedToken(&auth, &oauth2.CookieExtractor{Name: "session"}, opts.Logger, router) // Wrap the API with token validation middleware. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.HasPrefix(r.URL.Path, "/chronograf/v1/") { + if strings.HasPrefix(r.URL.Path, "/chronograf/v1/") || r.URL.Path == "/oauth/logout" { tokenMiddleware.ServeHTTP(w, r) return } diff --git a/server/routes.go b/server/routes.go index 4156ef28a..4dbef40b5 100644 --- a/server/routes.go +++ b/server/routes.go @@ -15,6 +15,19 @@ type AuthRoute struct { Callback string `json:"callback"` // Callback is the route the provider calls to exchange the code/state } +// AuthRoutes contains all OAuth2 provider routes. +type AuthRoutes []AuthRoute + +// Lookup searches all the routes for a specific provider +func (r *AuthRoutes) Lookup(provider string) (AuthRoute, bool) { + for _, route := range *r { + if route.Name == provider { + return route, true + } + } + return AuthRoute{}, false +} + type getRoutesResponse struct { Layouts string `json:"layouts"` // Location of the layouts endpoint Mappings string `json:"mappings"` // Location of the application mappings endpoint diff --git a/server/users.go b/server/users.go index 624cce91c..9426ce0f2 100644 --- a/server/users.go +++ b/server/users.go @@ -141,11 +141,23 @@ func ValidUserRequest(s *chronograf.User) error { } func getEmail(ctx context.Context) (string, error) { - principal := ctx.Value(oauth2.PrincipalKey).(oauth2.Principal) - if principal == "" { + principal, err := getPrincipal(ctx) + if err != nil { + return "", err + } + if principal.Subject == "" { return "", fmt.Errorf("Token not found") } - return string(principal), nil + return principal.Subject, nil +} + +func getPrincipal(ctx context.Context) (oauth2.Principal, error) { + principal, ok := ctx.Value(oauth2.PrincipalKey).(oauth2.Principal) + if !ok { + return oauth2.Principal{}, fmt.Errorf("Token not found") + } + + return principal, nil } // Me does a findOrCreate based on the email in the context