diff --git a/organizations/dashboards_test.go b/organizations/dashboards_test.go index 7c3c91785..d45bb6354 100644 --- a/organizations/dashboards_test.go +++ b/organizations/dashboards_test.go @@ -78,7 +78,7 @@ func TestDashboards_All(t *testing.T) { } for _, tt := range tests { s := organizations.NewDashboardsStore(tt.fields.DashboardsStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) gots, err := s.All(tt.args.ctx) if (err != nil) != tt.wantErr { t.Errorf("%q. DashboardsStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -140,7 +140,7 @@ func TestDashboards_Add(t *testing.T) { } for _, tt := range tests { s := organizations.NewDashboardsStore(tt.fields.DashboardsStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) d, err := s.Add(tt.args.ctx, tt.args.dashboard) if (err != nil) != tt.wantErr { t.Errorf("%q. DashboardsStore.Add() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -200,7 +200,7 @@ func TestDashboards_Delete(t *testing.T) { } for _, tt := range tests { s := organizations.NewDashboardsStore(tt.fields.DashboardsStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) err := s.Delete(tt.args.ctx, tt.args.dashboard) if (err != nil) != tt.wantErr { t.Errorf("%q. DashboardsStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -257,7 +257,7 @@ func TestDashboards_Get(t *testing.T) { } for _, tt := range tests { s := organizations.NewDashboardsStore(tt.fields.DashboardsStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) got, err := s.Get(tt.args.ctx, tt.args.dashboard.ID) if (err != nil) != tt.wantErr { t.Errorf("%q. DashboardsStore.Get() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -325,7 +325,7 @@ func TestDashboards_Update(t *testing.T) { tt.args.dashboard.Name = tt.args.name } s := organizations.NewDashboardsStore(tt.fields.DashboardsStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) err := s.Update(tt.args.ctx, tt.args.dashboard) if (err != nil) != tt.wantErr { t.Errorf("%q. DashboardsStore.Update() error = %v, wantErr %v", tt.name, err, tt.wantErr) diff --git a/organizations/layouts_test.go b/organizations/layouts_test.go index 94f51b7de..5969b4cc1 100644 --- a/organizations/layouts_test.go +++ b/organizations/layouts_test.go @@ -78,7 +78,7 @@ func TestLayouts_All(t *testing.T) { } for _, tt := range tests { s := organizations.NewLayoutsStore(tt.fields.LayoutsStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) gots, err := s.All(tt.args.ctx) if (err != nil) != tt.wantErr { t.Errorf("%q. LayoutsStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -140,7 +140,7 @@ func TestLayouts_Add(t *testing.T) { } for _, tt := range tests { s := organizations.NewLayoutsStore(tt.fields.LayoutsStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) d, err := s.Add(tt.args.ctx, tt.args.layout) if (err != nil) != tt.wantErr { t.Errorf("%q. LayoutsStore.Add() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -200,7 +200,7 @@ func TestLayouts_Delete(t *testing.T) { } for _, tt := range tests { s := organizations.NewLayoutsStore(tt.fields.LayoutsStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) err := s.Delete(tt.args.ctx, tt.args.layout) if (err != nil) != tt.wantErr { t.Errorf("%q. LayoutsStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -257,7 +257,7 @@ func TestLayouts_Get(t *testing.T) { } for _, tt := range tests { s := organizations.NewLayoutsStore(tt.fields.LayoutsStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) got, err := s.Get(tt.args.ctx, tt.args.layout.ID) if (err != nil) != tt.wantErr { t.Errorf("%q. LayoutsStore.Get() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -325,7 +325,7 @@ func TestLayouts_Update(t *testing.T) { tt.args.layout.Application = tt.args.name } s := organizations.NewLayoutsStore(tt.fields.LayoutsStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) err := s.Update(tt.args.ctx, tt.args.layout) if (err != nil) != tt.wantErr { t.Errorf("%q. LayoutsStore.Update() error = %v, wantErr %v", tt.name, err, tt.wantErr) diff --git a/organizations/organizations.go b/organizations/organizations.go index cb757a2d1..fe5f94e6a 100644 --- a/organizations/organizations.go +++ b/organizations/organizations.go @@ -5,14 +5,16 @@ import ( "fmt" ) -const organizationKey = "organizationID" +type contextKey string + +const ContextKey = contextKey("organization") func validOrganization(ctx context.Context) error { // prevents panic in case of nil context if ctx == nil { return fmt.Errorf("expect non nil context") } - orgID, ok := ctx.Value(organizationKey).(string) + orgID, ok := ctx.Value(ContextKey).(string) // should never happen if !ok { return fmt.Errorf("expected organization key to be a string") diff --git a/organizations/servers_test.go b/organizations/servers_test.go index e714b133b..0af4b524b 100644 --- a/organizations/servers_test.go +++ b/organizations/servers_test.go @@ -79,7 +79,7 @@ func TestServers_All(t *testing.T) { } for _, tt := range tests { s := organizations.NewServersStore(tt.fields.ServersStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) gots, err := s.All(tt.args.ctx) if (err != nil) != tt.wantErr { t.Errorf("%q. ServersStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -141,7 +141,7 @@ func TestServers_Add(t *testing.T) { } for _, tt := range tests { s := organizations.NewServersStore(tt.fields.ServersStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) d, err := s.Add(tt.args.ctx, tt.args.server) if (err != nil) != tt.wantErr { t.Errorf("%q. ServersStore.Add() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -201,7 +201,7 @@ func TestServers_Delete(t *testing.T) { } for _, tt := range tests { s := organizations.NewServersStore(tt.fields.ServersStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) err := s.Delete(tt.args.ctx, tt.args.server) if (err != nil) != tt.wantErr { t.Errorf("%q. ServersStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -258,7 +258,7 @@ func TestServers_Get(t *testing.T) { } for _, tt := range tests { s := organizations.NewServersStore(tt.fields.ServersStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) got, err := s.Get(tt.args.ctx, tt.args.server.ID) if (err != nil) != tt.wantErr { t.Errorf("%q. ServersStore.Get() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -326,7 +326,7 @@ func TestServers_Update(t *testing.T) { tt.args.server.Name = tt.args.name } s := organizations.NewServersStore(tt.fields.ServersStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) err := s.Update(tt.args.ctx, tt.args.server) if (err != nil) != tt.wantErr { t.Errorf("%q. ServersStore.Update() error = %v, wantErr %v", tt.name, err, tt.wantErr) diff --git a/organizations/sources_test.go b/organizations/sources_test.go index 46a0140c7..3496167fd 100644 --- a/organizations/sources_test.go +++ b/organizations/sources_test.go @@ -79,7 +79,7 @@ func TestSources_All(t *testing.T) { } for _, tt := range tests { s := organizations.NewSourcesStore(tt.fields.SourcesStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) gots, err := s.All(tt.args.ctx) if (err != nil) != tt.wantErr { t.Errorf("%q. SourcesStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -141,7 +141,7 @@ func TestSources_Add(t *testing.T) { } for _, tt := range tests { s := organizations.NewSourcesStore(tt.fields.SourcesStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) d, err := s.Add(tt.args.ctx, tt.args.source) if (err != nil) != tt.wantErr { t.Errorf("%q. SourcesStore.Add() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -201,7 +201,7 @@ func TestSources_Delete(t *testing.T) { } for _, tt := range tests { s := organizations.NewSourcesStore(tt.fields.SourcesStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) err := s.Delete(tt.args.ctx, tt.args.source) if (err != nil) != tt.wantErr { t.Errorf("%q. SourcesStore.All() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -258,7 +258,7 @@ func TestSources_Get(t *testing.T) { } for _, tt := range tests { s := organizations.NewSourcesStore(tt.fields.SourcesStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) got, err := s.Get(tt.args.ctx, tt.args.source.ID) if (err != nil) != tt.wantErr { t.Errorf("%q. SourcesStore.Get() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -326,7 +326,7 @@ func TestSources_Update(t *testing.T) { tt.args.source.Name = tt.args.name } s := organizations.NewSourcesStore(tt.fields.SourcesStore, tt.args.organization) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.organization) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.organization) err := s.Update(tt.args.ctx, tt.args.source) if (err != nil) != tt.wantErr { t.Errorf("%q. SourcesStore.Update() error = %v, wantErr %v", tt.name, err, tt.wantErr) diff --git a/organizations/users_test.go b/organizations/users_test.go index b4ec05f9d..7973bdd4c 100644 --- a/organizations/users_test.go +++ b/organizations/users_test.go @@ -132,7 +132,7 @@ func TestUsersStore_Get(t *testing.T) { } for _, tt := range tests { s := organizations.NewUsersStore(tt.fields.UsersStore, tt.args.orgID) - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.orgID) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.orgID) got, err := s.Get(tt.args.ctx, chronograf.UserQuery{ID: &tt.args.userID}) if (err != nil) != tt.wantErr { t.Errorf("%q. UsersStore.Get() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -448,7 +448,7 @@ func TestUsersStore_Add(t *testing.T) { }, } for _, tt := range tests { - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.orgID) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.orgID) s := organizations.NewUsersStore(tt.fields.UsersStore, tt.args.orgID) got, err := s.Add(tt.args.ctx, tt.args.u) @@ -548,7 +548,7 @@ func TestUsersStore_Delete(t *testing.T) { }, } for _, tt := range tests { - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.orgID) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.orgID) s := organizations.NewUsersStore(tt.fields.UsersStore, tt.args.orgID) if err := s.Delete(tt.args.ctx, tt.args.user); (err != nil) != tt.wantErr { t.Errorf("%q. UsersStore.Delete() error = %v, wantErr %v", tt.name, err, tt.wantErr) @@ -648,7 +648,7 @@ func TestUsersStore_Update(t *testing.T) { }, } for _, tt := range tests { - tt.args.ctx = context.WithValue(tt.args.ctx, "organizationID", tt.args.orgID) + tt.args.ctx = context.WithValue(tt.args.ctx, organizations.ContextKey, tt.args.orgID) s := organizations.NewUsersStore(tt.fields.UsersStore, tt.args.orgID) if tt.args.roles != nil { @@ -806,7 +806,7 @@ func TestUsersStore_All(t *testing.T) { }, } for _, tt := range tests { - tt.ctx = context.WithValue(tt.ctx, "organizationID", tt.orgID) + tt.ctx = context.WithValue(tt.ctx, organizations.ContextKey, tt.orgID) for _, u := range tt.wantRaw { tt.fields.UsersStore.Add(tt.ctx, &u) } diff --git a/server/auth.go b/server/auth.go index 0ff123567..11e7a786c 100644 --- a/server/auth.go +++ b/server/auth.go @@ -8,6 +8,7 @@ import ( "github.com/influxdata/chronograf" "github.com/influxdata/chronograf/oauth2" + "github.com/influxdata/chronograf/organizations" ) // AuthorizedToken extracts the token and validates; if valid the next handler @@ -104,9 +105,8 @@ func AuthorizedUser( return } - ctx = context.WithValue(ctx, "organizationID", p.Organization) - // TODO: add real implementation - serverCtx := context.WithValue(ctx, "superadmin", true) + ctx = context.WithValue(ctx, organizations.ContextKey, p.Organization) + serverCtx := context.WithValue(ctx, SuperAdminKey, true) // TODO: seems silly to look up a user twice u, err := store.Users(serverCtx).Get(serverCtx, chronograf.UserQuery{ Name: &p.Subject, diff --git a/server/me.go b/server/me.go index ae095fa2d..384fc610b 100644 --- a/server/me.go +++ b/server/me.go @@ -10,6 +10,7 @@ import ( "github.com/influxdata/chronograf" "github.com/influxdata/chronograf/oauth2" + "github.com/influxdata/chronograf/organizations" ) type meLinks struct { @@ -136,7 +137,7 @@ func (s *Service) MeOrganization(auth oauth2.Authenticator) func(http.ResponseWr return } // validate that user belongs to organization - ctx = context.WithValue(ctx, "organizationID", req.OrganizationID) + ctx = context.WithValue(ctx, organizations.ContextKey, req.OrganizationID) _, err = s.Store.Users(ctx).Get(ctx, chronograf.UserQuery{ Name: &p.Subject, Provider: &p.Issuer, @@ -185,11 +186,8 @@ func (s *Service) Me(w http.ResponseWriter, r *http.Request) { invalidData(w, err, s.Logger) return } - // TODO: add real implementation - ctx = context.WithValue(ctx, "organizationID", p.Organization) - - // TODO: add real implementation - ctx = context.WithValue(ctx, "superadmin", true) + ctx = context.WithValue(ctx, organizations.ContextKey, p.Organization) + ctx = context.WithValue(ctx, SuperAdminKey, true) usr, err := s.Store.Users(ctx).Get(ctx, chronograf.UserQuery{ Name: &p.Subject, @@ -216,7 +214,6 @@ func (s *Service) Me(w http.ResponseWriter, r *http.Request) { // support OAuth2. This hard-coding should be removed whenever we add // support for other authentication schemes. Scheme: scheme, - // TODO: this should be member Roles: []chronograf.Role{ { Name: MemberRoleName, @@ -227,9 +224,6 @@ func (s *Service) Me(w http.ResponseWriter, r *http.Request) { SuperAdmin: s.firstUser(), } - // TODO: add real implementation - ctx = context.WithValue(ctx, "superadmin", true) - newUser, err := s.Store.Users(ctx).Add(ctx, user) if err != nil { msg := fmt.Errorf("error storing user %s: %v", user.Name, err) @@ -244,7 +238,7 @@ func (s *Service) Me(w http.ResponseWriter, r *http.Request) { // TODO(desa): very slow func (s *Service) firstUser() bool { - ctx := context.WithValue(context.Background(), "superadmin", true) + ctx := context.WithValue(context.Background(), SuperAdminKey, true) users, err := s.Store.Users(ctx).All(ctx) if err != nil { return false diff --git a/server/stores.go b/server/stores.go index 0a1a4280a..9237429f2 100644 --- a/server/stores.go +++ b/server/stores.go @@ -8,14 +8,12 @@ import ( "github.com/influxdata/chronograf/organizations" ) -const organizationKey = "organizationID" - func hasOrganizationContext(ctx context.Context) (string, bool) { // prevents panic in case of nil context if ctx == nil { return "", false } - orgID, ok := ctx.Value(organizationKey).(string) + orgID, ok := ctx.Value(organizations.ContextKey).(string) // should never happen if !ok { return "", false @@ -26,14 +24,16 @@ func hasOrganizationContext(ctx context.Context) (string, bool) { return orgID, true } -const superAdminKey = "superadmin" +type superAdminKey string + +const SuperAdminKey = superAdminKey("superadmin") func hasSuperAdminContext(ctx context.Context) (bool, bool) { // prevents panic in case of nil context if ctx == nil { return false, false } - sa, ok := ctx.Value(superAdminKey).(bool) + sa, ok := ctx.Value(SuperAdminKey).(bool) // should never happen if !ok { return false, false