User explicit type when setting context
parent
dd7dac6a5f
commit
3eaca382d3
|
@ -78,7 +78,7 @@ func TestDashboards_All(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewDashboardsStore(tt.fields.DashboardsStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
gots, err := s.All(tt.args.ctx)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. DashboardsStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -140,7 +140,7 @@ func TestDashboards_Add(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewDashboardsStore(tt.fields.DashboardsStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
d, err := s.Add(tt.args.ctx, tt.args.dashboard)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. DashboardsStore.Add() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -200,7 +200,7 @@ func TestDashboards_Delete(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewDashboardsStore(tt.fields.DashboardsStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
err := s.Delete(tt.args.ctx, tt.args.dashboard)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. DashboardsStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -257,7 +257,7 @@ func TestDashboards_Get(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewDashboardsStore(tt.fields.DashboardsStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
got, err := s.Get(tt.args.ctx, tt.args.dashboard.ID)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. DashboardsStore.Get() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -325,7 +325,7 @@ func TestDashboards_Update(t *testing.T) {
|
|||
tt.args.dashboard.Name = tt.args.name
|
||||
}
|
||||
s := organizations.NewDashboardsStore(tt.fields.DashboardsStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
err := s.Update(tt.args.ctx, tt.args.dashboard)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. DashboardsStore.Update() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
|
|
@ -78,7 +78,7 @@ func TestLayouts_All(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewLayoutsStore(tt.fields.LayoutsStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
gots, err := s.All(tt.args.ctx)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. LayoutsStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -140,7 +140,7 @@ func TestLayouts_Add(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewLayoutsStore(tt.fields.LayoutsStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
d, err := s.Add(tt.args.ctx, tt.args.layout)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. LayoutsStore.Add() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -200,7 +200,7 @@ func TestLayouts_Delete(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewLayoutsStore(tt.fields.LayoutsStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
err := s.Delete(tt.args.ctx, tt.args.layout)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. LayoutsStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -257,7 +257,7 @@ func TestLayouts_Get(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewLayoutsStore(tt.fields.LayoutsStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
got, err := s.Get(tt.args.ctx, tt.args.layout.ID)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. LayoutsStore.Get() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -325,7 +325,7 @@ func TestLayouts_Update(t *testing.T) {
|
|||
tt.args.layout.Application = tt.args.name
|
||||
}
|
||||
s := organizations.NewLayoutsStore(tt.fields.LayoutsStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
err := s.Update(tt.args.ctx, tt.args.layout)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. LayoutsStore.Update() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
|
|
@ -5,14 +5,16 @@ import (
|
|||
"fmt"
|
||||
)
|
||||
|
||||
const organizationKey = "organizationID"
|
||||
type contextKey string
|
||||
|
||||
const ContextKey = contextKey("organization")
|
||||
|
||||
func validOrganization(ctx context.Context) error {
|
||||
// prevents panic in case of nil context
|
||||
if ctx == nil {
|
||||
return fmt.Errorf("expect non nil context")
|
||||
}
|
||||
orgID, ok := ctx.Value(organizationKey).(string)
|
||||
orgID, ok := ctx.Value(ContextKey).(string)
|
||||
// should never happen
|
||||
if !ok {
|
||||
return fmt.Errorf("expected organization key to be a string")
|
||||
|
|
|
@ -79,7 +79,7 @@ func TestServers_All(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewServersStore(tt.fields.ServersStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
gots, err := s.All(tt.args.ctx)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. ServersStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -141,7 +141,7 @@ func TestServers_Add(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewServersStore(tt.fields.ServersStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
d, err := s.Add(tt.args.ctx, tt.args.server)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. ServersStore.Add() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -201,7 +201,7 @@ func TestServers_Delete(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewServersStore(tt.fields.ServersStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
err := s.Delete(tt.args.ctx, tt.args.server)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. ServersStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -258,7 +258,7 @@ func TestServers_Get(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewServersStore(tt.fields.ServersStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
got, err := s.Get(tt.args.ctx, tt.args.server.ID)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. ServersStore.Get() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -326,7 +326,7 @@ func TestServers_Update(t *testing.T) {
|
|||
tt.args.server.Name = tt.args.name
|
||||
}
|
||||
s := organizations.NewServersStore(tt.fields.ServersStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
err := s.Update(tt.args.ctx, tt.args.server)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. ServersStore.Update() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
|
|
@ -79,7 +79,7 @@ func TestSources_All(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewSourcesStore(tt.fields.SourcesStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
gots, err := s.All(tt.args.ctx)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. SourcesStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -141,7 +141,7 @@ func TestSources_Add(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewSourcesStore(tt.fields.SourcesStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
d, err := s.Add(tt.args.ctx, tt.args.source)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. SourcesStore.Add() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -201,7 +201,7 @@ func TestSources_Delete(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewSourcesStore(tt.fields.SourcesStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
err := s.Delete(tt.args.ctx, tt.args.source)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. SourcesStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -258,7 +258,7 @@ func TestSources_Get(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewSourcesStore(tt.fields.SourcesStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
got, err := s.Get(tt.args.ctx, tt.args.source.ID)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. SourcesStore.Get() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -326,7 +326,7 @@ func TestSources_Update(t *testing.T) {
|
|||
tt.args.source.Name = tt.args.name
|
||||
}
|
||||
s := organizations.NewSourcesStore(tt.fields.SourcesStore, tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization)
|
||||
err := s.Update(tt.args.ctx, tt.args.source)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. SourcesStore.Update() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
|
|
@ -132,7 +132,7 @@ func TestUsersStore_Get(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
s := organizations.NewUsersStore(tt.fields.UsersStore, tt.args.orgID)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.orgID)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.orgID)
|
||||
got, err := s.Get(tt.args.ctx, chronograf.UserQuery{ID: &tt.args.userID})
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. UsersStore.Get() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -448,7 +448,7 @@ func TestUsersStore_Add(t *testing.T) {
|
|||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.orgID)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.orgID)
|
||||
s := organizations.NewUsersStore(tt.fields.UsersStore, tt.args.orgID)
|
||||
|
||||
got, err := s.Add(tt.args.ctx, tt.args.u)
|
||||
|
@ -548,7 +548,7 @@ func TestUsersStore_Delete(t *testing.T) {
|
|||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.orgID)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.orgID)
|
||||
s := organizations.NewUsersStore(tt.fields.UsersStore, tt.args.orgID)
|
||||
if err := s.Delete(tt.args.ctx, tt.args.user); (err != nil) != tt.wantErr {
|
||||
t.Errorf("%q. UsersStore.Delete() error = %v, wantErr %v", tt.name, err, tt.wantErr)
|
||||
|
@ -648,7 +648,7 @@ func TestUsersStore_Update(t *testing.T) {
|
|||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.orgID)
|
||||
tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.orgID)
|
||||
s := organizations.NewUsersStore(tt.fields.UsersStore, tt.args.orgID)
|
||||
|
||||
if tt.args.roles != nil {
|
||||
|
@ -806,7 +806,7 @@ func TestUsersStore_All(t *testing.T) {
|
|||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt.ctx = context.WithValue(tt.ctx, "organizationID", tt.orgID)
|
||||
tt.ctx = context.WithValue(tt.ctx, organizations.ContextKey, tt.orgID)
|
||||
for _, u := range tt.wantRaw {
|
||||
tt.fields.UsersStore.Add(tt.ctx, &u)
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/influxdata/chronograf"
|
||||
"github.com/influxdata/chronograf/oauth2"
|
||||
"github.com/influxdata/chronograf/organizations"
|
||||
)
|
||||
|
||||
// AuthorizedToken extracts the token and validates; if valid the next handler
|
||||
|
@ -104,9 +105,8 @@ func AuthorizedUser(
|
|||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, "organizationID", p.Organization)
|
||||
// TODO: add real implementation
|
||||
serverCtx := context.WithValue(ctx, "superadmin", true)
|
||||
ctx = context.WithValue(ctx, organizations.ContextKey, p.Organization)
|
||||
serverCtx := context.WithValue(ctx, SuperAdminKey, true)
|
||||
// TODO: seems silly to look up a user twice
|
||||
u, err := store.Users(serverCtx).Get(serverCtx, chronograf.UserQuery{
|
||||
Name: &p.Subject,
|
||||
|
|
16
server/me.go
16
server/me.go
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/influxdata/chronograf"
|
||||
"github.com/influxdata/chronograf/oauth2"
|
||||
"github.com/influxdata/chronograf/organizations"
|
||||
)
|
||||
|
||||
type meLinks struct {
|
||||
|
@ -136,7 +137,7 @@ func (s *Service) MeOrganization(auth oauth2.Authenticator) func(http.ResponseWr
|
|||
return
|
||||
}
|
||||
// validate that user belongs to organization
|
||||
ctx = context.WithValue(ctx, "organizationID", req.OrganizationID)
|
||||
ctx = context.WithValue(ctx, organizations.ContextKey, req.OrganizationID)
|
||||
_, err = s.Store.Users(ctx).Get(ctx, chronograf.UserQuery{
|
||||
Name: &p.Subject,
|
||||
Provider: &p.Issuer,
|
||||
|
@ -185,11 +186,8 @@ func (s *Service) Me(w http.ResponseWriter, r *http.Request) {
|
|||
invalidData(w, err, s.Logger)
|
||||
return
|
||||
}
|
||||
// TODO: add real implementation
|
||||
ctx = context.WithValue(ctx, "organizationID", p.Organization)
|
||||
|
||||
// TODO: add real implementation
|
||||
ctx = context.WithValue(ctx, "superadmin", true)
|
||||
ctx = context.WithValue(ctx, organizations.ContextKey, p.Organization)
|
||||
ctx = context.WithValue(ctx, SuperAdminKey, true)
|
||||
|
||||
usr, err := s.Store.Users(ctx).Get(ctx, chronograf.UserQuery{
|
||||
Name: &p.Subject,
|
||||
|
@ -216,7 +214,6 @@ func (s *Service) Me(w http.ResponseWriter, r *http.Request) {
|
|||
// support OAuth2. This hard-coding should be removed whenever we add
|
||||
// support for other authentication schemes.
|
||||
Scheme: scheme,
|
||||
// TODO: this should be member
|
||||
Roles: []chronograf.Role{
|
||||
{
|
||||
Name: MemberRoleName,
|
||||
|
@ -227,9 +224,6 @@ func (s *Service) Me(w http.ResponseWriter, r *http.Request) {
|
|||
SuperAdmin: s.firstUser(),
|
||||
}
|
||||
|
||||
// TODO: add real implementation
|
||||
ctx = context.WithValue(ctx, "superadmin", true)
|
||||
|
||||
newUser, err := s.Store.Users(ctx).Add(ctx, user)
|
||||
if err != nil {
|
||||
msg := fmt.Errorf("error storing user %s: %v", user.Name, err)
|
||||
|
@ -244,7 +238,7 @@ func (s *Service) Me(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// TODO(desa): very slow
|
||||
func (s *Service) firstUser() bool {
|
||||
ctx := context.WithValue(context.Background(), "superadmin", true)
|
||||
ctx := context.WithValue(context.Background(), SuperAdminKey, true)
|
||||
users, err := s.Store.Users(ctx).All(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
|
|
|
@ -8,14 +8,12 @@ import (
|
|||
"github.com/influxdata/chronograf/organizations"
|
||||
)
|
||||
|
||||
const organizationKey = "organizationID"
|
||||
|
||||
func hasOrganizationContext(ctx context.Context) (string, bool) {
|
||||
// prevents panic in case of nil context
|
||||
if ctx == nil {
|
||||
return "", false
|
||||
}
|
||||
orgID, ok := ctx.Value(organizationKey).(string)
|
||||
orgID, ok := ctx.Value(organizations.ContextKey).(string)
|
||||
// should never happen
|
||||
if !ok {
|
||||
return "", false
|
||||
|
@ -26,14 +24,16 @@ func hasOrganizationContext(ctx context.Context) (string, bool) {
|
|||
return orgID, true
|
||||
}
|
||||
|
||||
const superAdminKey = "superadmin"
|
||||
type superAdminKey string
|
||||
|
||||
const SuperAdminKey = superAdminKey("superadmin")
|
||||
|
||||
func hasSuperAdminContext(ctx context.Context) (bool, bool) {
|
||||
// prevents panic in case of nil context
|
||||
if ctx == nil {
|
||||
return false, false
|
||||
}
|
||||
sa, ok := ctx.Value(superAdminKey).(bool)
|
||||
sa, ok := ctx.Value(SuperAdminKey).(bool)
|
||||
// should never happen
|
||||
if !ok {
|
||||
return false, false
|
||||
|
|
Loading…
Reference in New Issue