Add unified OAuth2 logout route redirecting to provider logout
Signed-off-by: Tim Raymond <tim@timraymond.com>pull/922/head
parent
4f89e7c4a0
commit
2017944b68
|
@ -56,12 +56,12 @@ func AuthorizedToken(auth Authenticator, te TokenExtractor, logger chronograf.Lo
|
|||
token, err := te.Extract(r)
|
||||
if err != nil {
|
||||
// Happens when Provider okays authentication, but Token is bad
|
||||
log.Error("Unable to extract token")
|
||||
log.Info("Unauthenticated user")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
// We do not check the validity of the principal. Those
|
||||
// server further down the chain should do so.
|
||||
// served further down the chain should do so.
|
||||
principal, err := auth.Authenticate(r.Context(), token)
|
||||
if err != nil {
|
||||
log.Error("Invalid token")
|
||||
|
|
|
@ -154,10 +154,12 @@ func TestAuthorizedToken(t *testing.T) {
|
|||
AuthErr: errors.New("error"),
|
||||
},
|
||||
{
|
||||
Desc: "Authorized ok",
|
||||
Code: http.StatusOK,
|
||||
Principal: "Principal Strickland",
|
||||
Expected: "Principal Strickland",
|
||||
Desc: "Authorized ok",
|
||||
Code: http.StatusOK,
|
||||
Principal: oauth2.Principal{
|
||||
Subject: "Principal Strickland",
|
||||
},
|
||||
Expected: "Principal Strickland",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
|
|
|
@ -61,17 +61,20 @@ func (j *JWT) Authenticate(ctx context.Context, jwtToken string) (Principal, err
|
|||
// 4. Check if subject is not empty
|
||||
token, err := gojwt.ParseWithClaims(jwtToken, &Claims{}, alg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return Principal{}, err
|
||||
} else if !token.Valid {
|
||||
return "", err
|
||||
return Principal{}, err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unable to convert claims to standard claims")
|
||||
return Principal{}, fmt.Errorf("unable to convert claims to standard claims")
|
||||
}
|
||||
|
||||
return Principal(claims.Subject), nil
|
||||
return Principal{
|
||||
Subject: claims.Subject,
|
||||
Issuer: claims.Issuer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Token creates a signed JWT token from user that expires at Now + duration
|
||||
|
@ -81,7 +84,8 @@ func (j *JWT) Token(ctx context.Context, user Principal, duration time.Duration)
|
|||
now := j.Now().UTC()
|
||||
claims := &Claims{
|
||||
gojwt.StandardClaims{
|
||||
Subject: string(user),
|
||||
Subject: user.Subject,
|
||||
Issuer: user.Issuer,
|
||||
ExpiresAt: now.Add(duration).Unix(),
|
||||
IssuedAt: now.Unix(),
|
||||
NotBefore: now.Unix(),
|
||||
|
|
|
@ -21,35 +21,45 @@ func TestAuthenticate(t *testing.T) {
|
|||
Desc: "Test bad jwt token",
|
||||
Secret: "secret",
|
||||
Token: "badtoken",
|
||||
User: "",
|
||||
Err: errors.New("token contains an invalid number of segments"),
|
||||
User: oauth2.Principal{
|
||||
Subject: "",
|
||||
},
|
||||
Err: errors.New("token contains an invalid number of segments"),
|
||||
},
|
||||
{
|
||||
Desc: "Test valid jwt token",
|
||||
Secret: "secret",
|
||||
Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIvY2hyb25vZ3JhZi92MS91c2Vycy8xIiwibmFtZSI6IkRvYyBCcm93biIsImlhdCI6LTQ0Njc3NDQwMCwiZXhwIjotNDQ2Nzc0NDAwLCJuYmYiOi00NDY3NzQ0MDB9._rZ4gOIei9PizHOABH6kLcJTA3jm8ls0YnDxtz1qeUI",
|
||||
User: "/chronograf/v1/users/1",
|
||||
User: oauth2.Principal{
|
||||
Subject: "/chronograf/v1/users/1",
|
||||
},
|
||||
},
|
||||
{
|
||||
Desc: "Test expired jwt token",
|
||||
Secret: "secret",
|
||||
Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIvY2hyb25vZ3JhZi92MS91c2Vycy8xIiwibmFtZSI6IkRvYyBCcm93biIsImlhdCI6LTQ0Njc3NDQwMCwiZXhwIjotNDQ2Nzc0NDAxLCJuYmYiOi00NDY3NzQ0MDB9.vWXdm0-XQ_pW62yBpSISFFJN_yz0vqT9_INcUKTp5Q8",
|
||||
User: "",
|
||||
Err: errors.New("token is expired by 1s"),
|
||||
User: oauth2.Principal{
|
||||
Subject: "",
|
||||
},
|
||||
Err: errors.New("token is expired by 1s"),
|
||||
},
|
||||
{
|
||||
Desc: "Test jwt token not before time",
|
||||
Secret: "secret",
|
||||
Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIvY2hyb25vZ3JhZi92MS91c2Vycy8xIiwibmFtZSI6IkRvYyBCcm93biIsImlhdCI6LTQ0Njc3NDQwMCwiZXhwIjotNDQ2Nzc0NDAwLCJuYmYiOi00NDY3NzQzOTl9.TMGAhv57u1aosjc4ywKC7cElP1tKyQH7GmRF2ToAxlE",
|
||||
User: "",
|
||||
Err: errors.New("token is not valid yet"),
|
||||
User: oauth2.Principal{
|
||||
Subject: "",
|
||||
},
|
||||
Err: errors.New("token is not valid yet"),
|
||||
},
|
||||
{
|
||||
Desc: "Test jwt with empty subject is invalid",
|
||||
Secret: "secret",
|
||||
Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOi00NDY3NzQ0MDAsImV4cCI6LTQ0Njc3NDQwMCwibmJmIjotNDQ2Nzc0NDAwfQ.gxsA6_Ei3s0f2I1TAtrrb8FmGiO25OqVlktlF_ylhX4",
|
||||
User: "",
|
||||
Err: errors.New("claim has no subject"),
|
||||
User: oauth2.Principal{
|
||||
Subject: "",
|
||||
},
|
||||
Err: errors.New("claim has no subject"),
|
||||
},
|
||||
}
|
||||
for i, test := range tests {
|
||||
|
@ -82,7 +92,10 @@ func TestToken(t *testing.T) {
|
|||
return time.Unix(-446774400, 0)
|
||||
},
|
||||
}
|
||||
if token, err := j.Token(context.Background(), oauth2.Principal("/chronograf/v1/users/1"), duration); err != nil {
|
||||
p := oauth2.Principal{
|
||||
Subject: "/chronograf/v1/users/1",
|
||||
}
|
||||
if token, err := j.Token(context.Background(), p, duration); err != nil {
|
||||
t.Errorf("Error creating token for user: %v", err)
|
||||
} else if token != expected {
|
||||
t.Errorf("Error creating token; expected: %s actual: %s", "", token)
|
||||
|
|
|
@ -64,7 +64,10 @@ func (j *JWTMux) Login() http.Handler {
|
|||
// We'll give our users 10 minutes from this point to type in their github password.
|
||||
// If the callback is not received within 10 minutes, then authorization will fail.
|
||||
csrf := randomString(32) // 32 is not important... just long
|
||||
state, err := j.Auth.Token(r.Context(), Principal(csrf), 10*time.Minute)
|
||||
p := Principal{
|
||||
Subject: csrf,
|
||||
}
|
||||
state, err := j.Auth.Token(r.Context(), p, 10*time.Minute)
|
||||
// This is likely an internal server error
|
||||
if err != nil {
|
||||
j.Logger.
|
||||
|
@ -122,8 +125,12 @@ func (j *JWTMux) Callback() http.Handler {
|
|||
return
|
||||
}
|
||||
|
||||
p := Principal{
|
||||
Subject: id,
|
||||
Issuer: j.Provider.Name(),
|
||||
}
|
||||
// We create an auth token that will be used by all other endpoints to validate the principal has a claim
|
||||
authToken, err := j.Auth.Token(r.Context(), Principal(id), j.cookie.Duration)
|
||||
authToken, err := j.Auth.Token(r.Context(), p, j.cookie.Duration)
|
||||
if err != nil {
|
||||
log.Error("Unable to create cookie auth token ", err.Error())
|
||||
http.Redirect(w, r, j.FailureURL, http.StatusTemporaryRedirect)
|
||||
|
|
|
@ -14,7 +14,7 @@ const (
|
|||
// PrincipalKey is used to pass principal
|
||||
// via context.Context to request-scoped
|
||||
// functions.
|
||||
PrincipalKey Principal = "principal"
|
||||
PrincipalKey string = "principal"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -25,7 +25,10 @@ var (
|
|||
/* Types */
|
||||
|
||||
// Principal is any entity that can be authenticated
|
||||
type Principal string
|
||||
type Principal struct {
|
||||
Subject string
|
||||
Issuer string
|
||||
}
|
||||
|
||||
/* Interfaces */
|
||||
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
package server
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Logout chooses the correct provider logout route and redirects to it
|
||||
func Logout(nextURL string, routes AuthRoutes) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
principal, err := getPrincipal(ctx)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, nextURL, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
route, ok := routes.Lookup(principal.Issuer)
|
||||
if !ok {
|
||||
http.Redirect(w, r, nextURL, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, route.Logout, http.StatusTemporaryRedirect)
|
||||
}
|
||||
}
|
|
@ -42,8 +42,8 @@ func (m *MuxOpts) UseGoogle() bool {
|
|||
return m.TokenSecret != "" && m.GoogleClientID != "" && m.GoogleClientSecret != "" && m.PublicURL != ""
|
||||
}
|
||||
|
||||
func (m *MuxOpts) Routes() []AuthRoute {
|
||||
routes := []AuthRoute{}
|
||||
func (m *MuxOpts) Routes() AuthRoutes {
|
||||
routes := AuthRoutes{}
|
||||
if m.UseGithub() {
|
||||
routes = append(routes, NewGithubRoute())
|
||||
}
|
||||
|
@ -148,6 +148,11 @@ func NewMux(opts MuxOpts, service Service) http.Handler {
|
|||
|
||||
/* Authentication */
|
||||
if opts.UseAuth {
|
||||
// Create middleware to redirect to the appropriate provider logout
|
||||
targetURL := "/"
|
||||
router.GET("/oauth/logout", Logout(targetURL, authRoutes))
|
||||
|
||||
// Encapsulate the router with OAuth2
|
||||
auth := AuthAPI(opts, router)
|
||||
return Logger(opts.Logger, auth)
|
||||
}
|
||||
|
@ -157,6 +162,7 @@ func NewMux(opts MuxOpts, service Service) http.Handler {
|
|||
}
|
||||
|
||||
// AuthAPI adds the OAuth routes if auth is enabled.
|
||||
// TODO: this function is not great. Would be good if providers added their routes.
|
||||
func AuthAPI(opts MuxOpts, router *httprouter.Router) http.Handler {
|
||||
auth := oauth2.NewJWT(opts.TokenSecret)
|
||||
if opts.UseGithub() {
|
||||
|
@ -192,7 +198,7 @@ func AuthAPI(opts MuxOpts, router *httprouter.Router) http.Handler {
|
|||
tokenMiddleware := oauth2.AuthorizedToken(&auth, &oauth2.CookieExtractor{Name: "session"}, opts.Logger, router)
|
||||
// Wrap the API with token validation middleware.
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/chronograf/v1/") {
|
||||
if strings.HasPrefix(r.URL.Path, "/chronograf/v1/") || r.URL.Path == "/oauth/logout" {
|
||||
tokenMiddleware.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -15,6 +15,19 @@ type AuthRoute struct {
|
|||
Callback string `json:"callback"` // Callback is the route the provider calls to exchange the code/state
|
||||
}
|
||||
|
||||
// AuthRoutes contains all OAuth2 provider routes.
|
||||
type AuthRoutes []AuthRoute
|
||||
|
||||
// Lookup searches all the routes for a specific provider
|
||||
func (r *AuthRoutes) Lookup(provider string) (AuthRoute, bool) {
|
||||
for _, route := range *r {
|
||||
if route.Name == provider {
|
||||
return route, true
|
||||
}
|
||||
}
|
||||
return AuthRoute{}, false
|
||||
}
|
||||
|
||||
type getRoutesResponse struct {
|
||||
Layouts string `json:"layouts"` // Location of the layouts endpoint
|
||||
Mappings string `json:"mappings"` // Location of the application mappings endpoint
|
||||
|
|
|
@ -141,11 +141,23 @@ func ValidUserRequest(s *chronograf.User) error {
|
|||
}
|
||||
|
||||
func getEmail(ctx context.Context) (string, error) {
|
||||
principal := ctx.Value(oauth2.PrincipalKey).(oauth2.Principal)
|
||||
if principal == "" {
|
||||
principal, err := getPrincipal(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if principal.Subject == "" {
|
||||
return "", fmt.Errorf("Token not found")
|
||||
}
|
||||
return string(principal), nil
|
||||
return principal.Subject, nil
|
||||
}
|
||||
|
||||
func getPrincipal(ctx context.Context) (oauth2.Principal, error) {
|
||||
principal, ok := ctx.Value(oauth2.PrincipalKey).(oauth2.Principal)
|
||||
if !ok {
|
||||
return oauth2.Principal{}, fmt.Errorf("Token not found")
|
||||
}
|
||||
|
||||
return principal, nil
|
||||
}
|
||||
|
||||
// Me does a findOrCreate based on the email in the context
|
||||
|
|
Loading…
Reference in New Issue