Add unified OAuth2 logout route redirecting to provider logout

Signed-off-by: Tim Raymond <tim@timraymond.com>
pull/922/head
Chris Goller 2017-02-15 16:28:17 -06:00
parent 4f89e7c4a0
commit 2017944b68
10 changed files with 112 additions and 31 deletions

View File

@ -56,12 +56,12 @@ func AuthorizedToken(auth Authenticator, te TokenExtractor, logger chronograf.Lo
token, err := te.Extract(r)
if err != nil {
// Happens when Provider okays authentication, but Token is bad
log.Error("Unable to extract token")
log.Info("Unauthenticated user")
w.WriteHeader(http.StatusUnauthorized)
return
}
// We do not check the validity of the principal. Those
// server further down the chain should do so.
// served further down the chain should do so.
principal, err := auth.Authenticate(r.Context(), token)
if err != nil {
log.Error("Invalid token")

View File

@ -154,10 +154,12 @@ func TestAuthorizedToken(t *testing.T) {
AuthErr: errors.New("error"),
},
{
Desc: "Authorized ok",
Code: http.StatusOK,
Principal: "Principal Strickland",
Expected: "Principal Strickland",
Desc: "Authorized ok",
Code: http.StatusOK,
Principal: oauth2.Principal{
Subject: "Principal Strickland",
},
Expected: "Principal Strickland",
},
}
for _, test := range tests {

View File

@ -61,17 +61,20 @@ func (j *JWT) Authenticate(ctx context.Context, jwtToken string) (Principal, err
// 4. Check if subject is not empty
token, err := gojwt.ParseWithClaims(jwtToken, &Claims{}, alg)
if err != nil {
return "", err
return Principal{}, err
} else if !token.Valid {
return "", err
return Principal{}, err
}
claims, ok := token.Claims.(*Claims)
if !ok {
return "", fmt.Errorf("unable to convert claims to standard claims")
return Principal{}, fmt.Errorf("unable to convert claims to standard claims")
}
return Principal(claims.Subject), nil
return Principal{
Subject: claims.Subject,
Issuer: claims.Issuer,
}, nil
}
// Token creates a signed JWT token from user that expires at Now + duration
@ -81,7 +84,8 @@ func (j *JWT) Token(ctx context.Context, user Principal, duration time.Duration)
now := j.Now().UTC()
claims := &Claims{
gojwt.StandardClaims{
Subject: string(user),
Subject: user.Subject,
Issuer: user.Issuer,
ExpiresAt: now.Add(duration).Unix(),
IssuedAt: now.Unix(),
NotBefore: now.Unix(),

View File

@ -21,35 +21,45 @@ func TestAuthenticate(t *testing.T) {
Desc: "Test bad jwt token",
Secret: "secret",
Token: "badtoken",
User: "",
Err: errors.New("token contains an invalid number of segments"),
User: oauth2.Principal{
Subject: "",
},
Err: errors.New("token contains an invalid number of segments"),
},
{
Desc: "Test valid jwt token",
Secret: "secret",
Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIvY2hyb25vZ3JhZi92MS91c2Vycy8xIiwibmFtZSI6IkRvYyBCcm93biIsImlhdCI6LTQ0Njc3NDQwMCwiZXhwIjotNDQ2Nzc0NDAwLCJuYmYiOi00NDY3NzQ0MDB9._rZ4gOIei9PizHOABH6kLcJTA3jm8ls0YnDxtz1qeUI",
User: "/chronograf/v1/users/1",
User: oauth2.Principal{
Subject: "/chronograf/v1/users/1",
},
},
{
Desc: "Test expired jwt token",
Secret: "secret",
Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIvY2hyb25vZ3JhZi92MS91c2Vycy8xIiwibmFtZSI6IkRvYyBCcm93biIsImlhdCI6LTQ0Njc3NDQwMCwiZXhwIjotNDQ2Nzc0NDAxLCJuYmYiOi00NDY3NzQ0MDB9.vWXdm0-XQ_pW62yBpSISFFJN_yz0vqT9_INcUKTp5Q8",
User: "",
Err: errors.New("token is expired by 1s"),
User: oauth2.Principal{
Subject: "",
},
Err: errors.New("token is expired by 1s"),
},
{
Desc: "Test jwt token not before time",
Secret: "secret",
Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIvY2hyb25vZ3JhZi92MS91c2Vycy8xIiwibmFtZSI6IkRvYyBCcm93biIsImlhdCI6LTQ0Njc3NDQwMCwiZXhwIjotNDQ2Nzc0NDAwLCJuYmYiOi00NDY3NzQzOTl9.TMGAhv57u1aosjc4ywKC7cElP1tKyQH7GmRF2ToAxlE",
User: "",
Err: errors.New("token is not valid yet"),
User: oauth2.Principal{
Subject: "",
},
Err: errors.New("token is not valid yet"),
},
{
Desc: "Test jwt with empty subject is invalid",
Secret: "secret",
Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOi00NDY3NzQ0MDAsImV4cCI6LTQ0Njc3NDQwMCwibmJmIjotNDQ2Nzc0NDAwfQ.gxsA6_Ei3s0f2I1TAtrrb8FmGiO25OqVlktlF_ylhX4",
User: "",
Err: errors.New("claim has no subject"),
User: oauth2.Principal{
Subject: "",
},
Err: errors.New("claim has no subject"),
},
}
for i, test := range tests {
@ -82,7 +92,10 @@ func TestToken(t *testing.T) {
return time.Unix(-446774400, 0)
},
}
if token, err := j.Token(context.Background(), oauth2.Principal("/chronograf/v1/users/1"), duration); err != nil {
p := oauth2.Principal{
Subject: "/chronograf/v1/users/1",
}
if token, err := j.Token(context.Background(), p, duration); err != nil {
t.Errorf("Error creating token for user: %v", err)
} else if token != expected {
t.Errorf("Error creating token; expected: %s actual: %s", "", token)

View File

@ -64,7 +64,10 @@ func (j *JWTMux) Login() http.Handler {
// We'll give our users 10 minutes from this point to type in their github password.
// If the callback is not received within 10 minutes, then authorization will fail.
csrf := randomString(32) // 32 is not important... just long
state, err := j.Auth.Token(r.Context(), Principal(csrf), 10*time.Minute)
p := Principal{
Subject: csrf,
}
state, err := j.Auth.Token(r.Context(), p, 10*time.Minute)
// This is likely an internal server error
if err != nil {
j.Logger.
@ -122,8 +125,12 @@ func (j *JWTMux) Callback() http.Handler {
return
}
p := Principal{
Subject: id,
Issuer: j.Provider.Name(),
}
// We create an auth token that will be used by all other endpoints to validate the principal has a claim
authToken, err := j.Auth.Token(r.Context(), Principal(id), j.cookie.Duration)
authToken, err := j.Auth.Token(r.Context(), p, j.cookie.Duration)
if err != nil {
log.Error("Unable to create cookie auth token ", err.Error())
http.Redirect(w, r, j.FailureURL, http.StatusTemporaryRedirect)

View File

@ -14,7 +14,7 @@ const (
// PrincipalKey is used to pass principal
// via context.Context to request-scoped
// functions.
PrincipalKey Principal = "principal"
PrincipalKey string = "principal"
)
var (
@ -25,7 +25,10 @@ var (
/* Types */
// Principal is any entity that can be authenticated
type Principal string
type Principal struct {
Subject string
Issuer string
}
/* Interfaces */

21
server/logout.go Normal file
View File

@ -0,0 +1,21 @@
package server
import "net/http"
// Logout chooses the correct provider logout route and redirects to it
func Logout(nextURL string, routes AuthRoutes) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
principal, err := getPrincipal(ctx)
if err != nil {
http.Redirect(w, r, nextURL, http.StatusTemporaryRedirect)
return
}
route, ok := routes.Lookup(principal.Issuer)
if !ok {
http.Redirect(w, r, nextURL, http.StatusTemporaryRedirect)
return
}
http.Redirect(w, r, route.Logout, http.StatusTemporaryRedirect)
}
}

View File

@ -42,8 +42,8 @@ func (m *MuxOpts) UseGoogle() bool {
return m.TokenSecret != "" && m.GoogleClientID != "" && m.GoogleClientSecret != "" && m.PublicURL != ""
}
func (m *MuxOpts) Routes() []AuthRoute {
routes := []AuthRoute{}
func (m *MuxOpts) Routes() AuthRoutes {
routes := AuthRoutes{}
if m.UseGithub() {
routes = append(routes, NewGithubRoute())
}
@ -148,6 +148,11 @@ func NewMux(opts MuxOpts, service Service) http.Handler {
/* Authentication */
if opts.UseAuth {
// Create middleware to redirect to the appropriate provider logout
targetURL := "/"
router.GET("/oauth/logout", Logout(targetURL, authRoutes))
// Encapsulate the router with OAuth2
auth := AuthAPI(opts, router)
return Logger(opts.Logger, auth)
}
@ -157,6 +162,7 @@ func NewMux(opts MuxOpts, service Service) http.Handler {
}
// AuthAPI adds the OAuth routes if auth is enabled.
// TODO: this function is not great. Would be good if providers added their routes.
func AuthAPI(opts MuxOpts, router *httprouter.Router) http.Handler {
auth := oauth2.NewJWT(opts.TokenSecret)
if opts.UseGithub() {
@ -192,7 +198,7 @@ func AuthAPI(opts MuxOpts, router *httprouter.Router) http.Handler {
tokenMiddleware := oauth2.AuthorizedToken(&auth, &oauth2.CookieExtractor{Name: "session"}, opts.Logger, router)
// Wrap the API with token validation middleware.
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/chronograf/v1/") {
if strings.HasPrefix(r.URL.Path, "/chronograf/v1/") || r.URL.Path == "/oauth/logout" {
tokenMiddleware.ServeHTTP(w, r)
return
}

View File

@ -15,6 +15,19 @@ type AuthRoute struct {
Callback string `json:"callback"` // Callback is the route the provider calls to exchange the code/state
}
// AuthRoutes contains all OAuth2 provider routes.
type AuthRoutes []AuthRoute
// Lookup searches all the routes for a specific provider
func (r *AuthRoutes) Lookup(provider string) (AuthRoute, bool) {
for _, route := range *r {
if route.Name == provider {
return route, true
}
}
return AuthRoute{}, false
}
type getRoutesResponse struct {
Layouts string `json:"layouts"` // Location of the layouts endpoint
Mappings string `json:"mappings"` // Location of the application mappings endpoint

View File

@ -141,11 +141,23 @@ func ValidUserRequest(s *chronograf.User) error {
}
func getEmail(ctx context.Context) (string, error) {
principal := ctx.Value(oauth2.PrincipalKey).(oauth2.Principal)
if principal == "" {
principal, err := getPrincipal(ctx)
if err != nil {
return "", err
}
if principal.Subject == "" {
return "", fmt.Errorf("Token not found")
}
return string(principal), nil
return principal.Subject, nil
}
func getPrincipal(ctx context.Context) (oauth2.Principal, error) {
principal, ok := ctx.Value(oauth2.PrincipalKey).(oauth2.Principal)
if !ok {
return oauth2.Principal{}, fmt.Errorf("Token not found")
}
return principal, nil
}
// Me does a findOrCreate based on the email in the context