Slightly DRYer code for getting fields off Principal

Signed-off-by: Michael de Sa <mjdesa@gmail.com>
pull/5028/head
Jared Scheib 2017-10-27 12:48:51 -07:00 committed by Michael de Sa
parent 511c3e1957
commit e0a535e78a
1 changed files with 24 additions and 34 deletions

View File

@ -38,6 +38,7 @@ func newMeResponse(usr *chronograf.User) meResponse {
}
}
// getUsername not currently used
func getUsername(ctx context.Context) (string, error) {
principal, err := getPrincipal(ctx)
if err != nil {
@ -49,6 +50,7 @@ func getUsername(ctx context.Context) (string, error) {
return principal.Subject, nil
}
// getProvider not currently used
func getProvider(ctx context.Context) (string, error) {
principal, err := getPrincipal(ctx)
if err != nil {
@ -76,14 +78,18 @@ func getPrincipal(ctx context.Context) (oauth2.Principal, error) {
return principal, nil
}
// This is the user's current chronograf organization and is not related to any
// concept of a OAuth organization.
func getOrganization(ctx context.Context) (string, error) {
principal, err := getPrincipal(ctx)
func getValidPrincipal(ctx context.Context) (oauth2.Principal, error) {
p, err := getPrincipal(ctx)
if err != nil {
return "", err
return p, err
}
return principal.Organization, nil
if p.Subject == "" {
return oauth2.Principal{}, fmt.Errorf("Token not found")
}
if p.Issuer == "" {
return oauth2.Principal{}, fmt.Errorf("Token not found")
}
return p, nil
}
type meOrganizationRequest struct {
@ -119,12 +125,7 @@ func (s *Service) MeOrganization(auth oauth2.Authenticator) func(http.ResponseWr
return
}
username, err := getUsername(ctx)
if err != nil {
invalidData(w, err, s.Logger)
return
}
provider, err := getProvider(ctx)
p, err := getValidPrincipal(ctx)
if err != nil {
invalidData(w, err, s.Logger)
return
@ -137,8 +138,8 @@ func (s *Service) MeOrganization(auth oauth2.Authenticator) func(http.ResponseWr
// validate that user belongs to organization
ctx = context.WithValue(ctx, "organizationID", req.OrganizationID)
_, err = s.OrganizationUsersStore.Get(ctx, chronograf.UserQuery{
Name: &username,
Provider: &provider,
Name: &p.Subject,
Provider: &p.Issuer,
Scheme: &scheme,
})
if err == chronograf.ErrUserNotFound {
@ -174,12 +175,7 @@ func (s *Service) Me(w http.ResponseWriter, r *http.Request) {
return
}
username, err := getUsername(ctx)
if err != nil {
invalidData(w, err, s.Logger)
return
}
provider, err := getProvider(ctx)
p, err := getValidPrincipal(ctx)
if err != nil {
invalidData(w, err, s.Logger)
return
@ -189,18 +185,12 @@ func (s *Service) Me(w http.ResponseWriter, r *http.Request) {
invalidData(w, err, s.Logger)
return
}
organization, err := getOrganization(ctx)
if err != nil {
invalidData(w, err, s.Logger)
return
}
// TODO: add real implementation
ctx = context.WithValue(ctx, "organizationID", organization)
ctx = context.WithValue(ctx, "organizationID", p.Organization)
usr, err := s.UsersStore.Get(ctx, chronograf.UserQuery{
Name: &username,
Provider: &provider,
Name: &p.Subject,
Provider: &p.Issuer,
Scheme: &scheme,
})
if err != nil && err != chronograf.ErrUserNotFound {
@ -209,7 +199,7 @@ func (s *Service) Me(w http.ResponseWriter, r *http.Request) {
}
if usr != nil {
usr.CurrentOrganization = organization
usr.CurrentOrganization = p.Organization
res := newMeResponse(usr)
encodeJSON(w, http.StatusOK, res, s.Logger)
return
@ -217,12 +207,12 @@ func (s *Service) Me(w http.ResponseWriter, r *http.Request) {
// Because we didnt find a user, making a new one
user := &chronograf.User{
Name: username,
Provider: provider,
Name: p.Subject,
Provider: p.Issuer,
// TODO: This Scheme value is hard-coded temporarily since we only currently
// support OAuth2. This hard-coding should be removed whenever we add
// support for other authentication schemes.
Scheme: "oauth2",
Scheme: scheme,
}
newUser, err := s.UsersStore.Add(ctx, user)
@ -232,7 +222,7 @@ func (s *Service) Me(w http.ResponseWriter, r *http.Request) {
return
}
newUser.CurrentOrganization = organization
newUser.CurrentOrganization = p.Organization
res := newMeResponse(newUser)
encodeJSON(w, http.StatusOK, res, s.Logger)
}