(feat/testing) add onboarding and basic auth service
parent
3e54ef9f53
commit
63da5d1e9f
|
@ -0,0 +1,69 @@
|
||||||
|
package bolt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
bolt "github.com/coreos/bbolt"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetPassword stores the password hash associated with a user.
|
||||||
|
func (c *Client) SetPassword(ctx context.Context, name string, password string) error {
|
||||||
|
return c.db.Update(func(tx *bolt.Tx) error {
|
||||||
|
return c.setPassword(ctx, tx, name, password)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashCost currently using the default cost of bcrypt
|
||||||
|
var HashCost = bcrypt.DefaultCost
|
||||||
|
|
||||||
|
func (c *Client) setPassword(ctx context.Context, tx *bolt.Tx, name string, password string) error {
|
||||||
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), HashCost)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := c.findUserByName(ctx, tx, name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
encodedID, err := u.ID.Encode()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Bucket(userpasswordBucket).Put(encodedID, hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ComparePassword compares a provided password with the stored password hash.
|
||||||
|
func (c *Client) ComparePassword(ctx context.Context, name string, password string) error {
|
||||||
|
return c.db.View(func(tx *bolt.Tx) error {
|
||||||
|
return c.comparePassword(ctx, tx, name, password)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
func (c *Client) comparePassword(ctx context.Context, tx *bolt.Tx, name string, password string) error {
|
||||||
|
u, err := c.findUserByName(ctx, tx, name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
encodedID, err := u.ID.Encode()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := tx.Bucket(userpasswordBucket).Get(encodedID)
|
||||||
|
|
||||||
|
return bcrypt.CompareHashAndPassword(hash, []byte(password))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompareAndSetPassword replaces the old password with the new password if thee old password is correct.
|
||||||
|
func (c *Client) CompareAndSetPassword(ctx context.Context, name string, old string, new string) error {
|
||||||
|
return c.db.Update(func(tx *bolt.Tx) error {
|
||||||
|
if err := c.comparePassword(ctx, tx, name, old); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.setPassword(ctx, tx, name, new)
|
||||||
|
})
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
package bolt_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/influxdata/platform"
|
||||||
|
platformtesting "github.com/influxdata/platform/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initBasicAuthService(f platformtesting.UserFields, t *testing.T) (platform.BasicAuthService, func()) {
|
||||||
|
c, closeFn, err := NewTestClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create new bolt client: %v", err)
|
||||||
|
}
|
||||||
|
c.IDGenerator = f.IDGenerator
|
||||||
|
ctx := context.Background()
|
||||||
|
for _, u := range f.Users {
|
||||||
|
if err := c.PutUser(ctx, u); err != nil {
|
||||||
|
t.Fatalf("failed to populate users")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c, func() {
|
||||||
|
defer closeFn()
|
||||||
|
for _, u := range f.Users {
|
||||||
|
if err := c.DeleteUser(ctx, u.ID); err != nil {
|
||||||
|
t.Logf("failed to remove users: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasicAuth(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
platformtesting.BasicAuth(initBasicAuthService, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasicAuth_CompareAndSet(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
platformtesting.CompareAndSetPassword(initBasicAuthService, t)
|
||||||
|
}
|
|
@ -4,10 +4,11 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/influxdata/platform"
|
||||||
platformtesting "github.com/influxdata/platform/testing"
|
platformtesting "github.com/influxdata/platform/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initOnboardingService(f platformtesting.OnboardingFields, t *testing.T) (platformtesting.OnBoardingNBasicAuthService, func()) {
|
func initOnboardingService(f platformtesting.OnboardingFields, t *testing.T) (platform.OnboardingService, func()) {
|
||||||
c, closeFn, err := NewTestClient()
|
c, closeFn, err := NewTestClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create new bolt client: %v", err)
|
t.Fatalf("failed to create new bolt client: %v", err)
|
||||||
|
|
61
bolt/user.go
61
bolt/user.go
|
@ -7,7 +7,6 @@ import (
|
||||||
|
|
||||||
"github.com/coreos/bbolt"
|
"github.com/coreos/bbolt"
|
||||||
"github.com/influxdata/platform"
|
"github.com/influxdata/platform"
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -341,63 +340,3 @@ func (c *Client) deleteUsersAuthorizations(ctx context.Context, tx *bolt.Tx, id
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetPassword stores the password hash associated with a user.
|
|
||||||
func (c *Client) SetPassword(ctx context.Context, name string, password string) error {
|
|
||||||
return c.db.Update(func(tx *bolt.Tx) error {
|
|
||||||
return c.setPassword(ctx, tx, name, password)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
var HashCost = bcrypt.DefaultCost
|
|
||||||
|
|
||||||
func (c *Client) setPassword(ctx context.Context, tx *bolt.Tx, name string, password string) error {
|
|
||||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), HashCost)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
u, err := c.findUserByName(ctx, tx, name)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
encodedID, err := u.ID.Encode()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx.Bucket(userpasswordBucket).Put(encodedID, hash)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ComparePassword compares a provided password with the stored password hash.
|
|
||||||
func (c *Client) ComparePassword(ctx context.Context, name string, password string) error {
|
|
||||||
return c.db.View(func(tx *bolt.Tx) error {
|
|
||||||
return c.comparePassword(ctx, tx, name, password)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
func (c *Client) comparePassword(ctx context.Context, tx *bolt.Tx, name string, password string) error {
|
|
||||||
u, err := c.findUserByName(ctx, tx, name)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
encodedID, err := u.ID.Encode()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
hash := tx.Bucket(userpasswordBucket).Get(encodedID)
|
|
||||||
|
|
||||||
return bcrypt.CompareHashAndPassword(hash, []byte(password))
|
|
||||||
}
|
|
||||||
|
|
||||||
// CompareAndSetPassword replaces the old password with the new password if thee old password is correct.
|
|
||||||
func (c *Client) CompareAndSetPassword(ctx context.Context, name string, old string, new string) error {
|
|
||||||
return c.db.Update(func(tx *bolt.Tx) error {
|
|
||||||
if err := c.comparePassword(ctx, tx, name, old); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return c.setPassword(ctx, tx, name, new)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
|
@ -29,6 +29,7 @@ func filterMappingsFn(filter platform.UserResourceMappingFilter) func(m *platfor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FindUserResourceMappings returns a list of UserResourceMappings that match filter and the total count of matching mappings.
|
||||||
func (c *Client) FindUserResourceMappings(ctx context.Context, filter platform.UserResourceMappingFilter, opt ...platform.FindOptions) ([]*platform.UserResourceMapping, int, error) {
|
func (c *Client) FindUserResourceMappings(ctx context.Context, filter platform.UserResourceMappingFilter, opt ...platform.FindOptions) ([]*platform.UserResourceMapping, int, error) {
|
||||||
ms := []*platform.UserResourceMapping{}
|
ms := []*platform.UserResourceMapping{}
|
||||||
err := c.db.View(func(tx *bolt.Tx) error {
|
err := c.db.View(func(tx *bolt.Tx) error {
|
||||||
|
@ -137,6 +138,7 @@ func (c *Client) uniqueUserResourceMapping(ctx context.Context, tx *bolt.Tx, m *
|
||||||
return len(v) == 0
|
return len(v) == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteUserResourceMapping deletes a user resource mapping.
|
||||||
func (c *Client) DeleteUserResourceMapping(ctx context.Context, resourceID platform.ID, userID platform.ID) error {
|
func (c *Client) DeleteUserResourceMapping(ctx context.Context, resourceID platform.ID, userID platform.ID) error {
|
||||||
return c.db.Update(func(tx *bolt.Tx) error {
|
return c.db.Update(func(tx *bolt.Tx) error {
|
||||||
return c.deleteUserResourceMapping(ctx, tx, platform.UserResourceMappingFilter{
|
return c.deleteUserResourceMapping(ctx, tx, platform.UserResourceMappingFilter{
|
||||||
|
|
|
@ -2,7 +2,6 @@ package bolt_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/influxdata/platform"
|
"github.com/influxdata/platform"
|
||||||
|
@ -54,186 +53,3 @@ func TestUserService_FindUser(t *testing.T) {
|
||||||
func TestUserService_UpdateUser(t *testing.T) {
|
func TestUserService_UpdateUser(t *testing.T) {
|
||||||
platformtesting.UpdateUser(initUserService, t)
|
platformtesting.UpdateUser(initUserService, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBasicAuth(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
users []*platform.User
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
name string
|
|
||||||
user string
|
|
||||||
setPassword string
|
|
||||||
comparePassword string
|
|
||||||
}
|
|
||||||
type wants struct {
|
|
||||||
setErr error
|
|
||||||
compareErr error
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
wants wants
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
fields: fields{
|
|
||||||
users: []*platform.User{
|
|
||||||
{
|
|
||||||
Name: "user1",
|
|
||||||
ID: platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
args: args{
|
|
||||||
name: "happy path",
|
|
||||||
user: "user1",
|
|
||||||
setPassword: "hello",
|
|
||||||
comparePassword: "hello",
|
|
||||||
},
|
|
||||||
wants: wants{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
fields: fields{
|
|
||||||
users: []*platform.User{
|
|
||||||
{
|
|
||||||
Name: "user1",
|
|
||||||
ID: platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
args: args{
|
|
||||||
name: "happy path dont match",
|
|
||||||
user: "user1",
|
|
||||||
setPassword: "hello",
|
|
||||||
comparePassword: "world",
|
|
||||||
},
|
|
||||||
wants: wants{
|
|
||||||
compareErr: fmt.Errorf("crypto/bcrypt: hashedPassword is not the hash of the given password"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.args.name, func(t *testing.T) {
|
|
||||||
c, closeFn, err := NewTestClient()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create new bolt client: %v", err)
|
|
||||||
}
|
|
||||||
defer closeFn()
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
for _, user := range tt.fields.users {
|
|
||||||
if err := c.PutUser(ctx, user); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.SetPassword(ctx, tt.args.user, tt.args.setPassword)
|
|
||||||
|
|
||||||
if (err != nil && tt.wants.setErr == nil) || (err == nil && tt.wants.setErr != nil) {
|
|
||||||
t.Fatalf("expected SetPassword error %v got %v", tt.wants.setErr, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if want, got := tt.wants.setErr.Error(), err.Error(); want != got {
|
|
||||||
t.Fatalf("expected SetPassword error %v got %v", want, got)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.ComparePassword(ctx, tt.args.user, tt.args.comparePassword)
|
|
||||||
|
|
||||||
if (err != nil && tt.wants.compareErr == nil) || (err == nil && tt.wants.compareErr != nil) {
|
|
||||||
t.Fatalf("expected ComparePassword error %v got %v", tt.wants.compareErr, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if want, got := tt.wants.compareErr.Error(), err.Error(); want != got {
|
|
||||||
t.Fatalf("expected ComparePassword error %v got %v", tt.wants.compareErr, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBasicAuth_CompareAndSet(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
users []*platform.User
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
name string
|
|
||||||
user string
|
|
||||||
old string
|
|
||||||
new string
|
|
||||||
}
|
|
||||||
type wants struct {
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
wants wants
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
fields: fields{
|
|
||||||
users: []*platform.User{
|
|
||||||
{
|
|
||||||
Name: "user1",
|
|
||||||
ID: platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
args: args{
|
|
||||||
name: "happy path",
|
|
||||||
user: "user1",
|
|
||||||
old: "hello",
|
|
||||||
new: "hello",
|
|
||||||
},
|
|
||||||
wants: wants{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.args.name, func(t *testing.T) {
|
|
||||||
c, closeFn, err := NewTestClient()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create new bolt client: %v", err)
|
|
||||||
}
|
|
||||||
defer closeFn()
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
for _, user := range tt.fields.users {
|
|
||||||
if err := c.PutUser(ctx, user); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.SetPassword(ctx, tt.args.user, tt.args.old); err != nil {
|
|
||||||
t.Fatalf("unexpected error %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.CompareAndSetPassword(ctx, tt.args.user, tt.args.old, tt.args.new)
|
|
||||||
|
|
||||||
if (err != nil && tt.wants.err == nil) || (err == nil && tt.wants.err != nil) {
|
|
||||||
t.Fatalf("expected CompareAndSetPassword error %v got %v", tt.wants.err, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if want, got := tt.wants.err.Error(), err.Error(); want != got {
|
|
||||||
t.Fatalf("expected CompareAndSetPassword error %v got %v", tt.wants.err, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
|
@ -52,7 +52,8 @@ func CheckErrorStatus(code int, res *http.Response) error {
|
||||||
// be determined in that way, it will create a generic error message.
|
// be determined in that way, it will create a generic error message.
|
||||||
//
|
//
|
||||||
// If there is no error, then this returns nil.
|
// If there is no error, then this returns nil.
|
||||||
func CheckError(resp *http.Response) error {
|
// Add an optional isPlatformError, to do decode with platform.Error
|
||||||
|
func CheckError(resp *http.Response, isPlatformError ...bool) (err error) {
|
||||||
switch resp.StatusCode / 100 {
|
switch resp.StatusCode / 100 {
|
||||||
case 4, 5:
|
case 4, 5:
|
||||||
// We will attempt to parse this error outside of this block.
|
// We will attempt to parse this error outside of this block.
|
||||||
|
@ -62,7 +63,15 @@ func CheckError(resp *http.Response) error {
|
||||||
// TODO(jsternberg): Figure out what to do here?
|
// TODO(jsternberg): Figure out what to do here?
|
||||||
return kerrors.InternalErrorf("unexpected status code: %d %s", resp.StatusCode, resp.Status)
|
return kerrors.InternalErrorf("unexpected status code: %d %s", resp.StatusCode, resp.Status)
|
||||||
}
|
}
|
||||||
|
if len(isPlatformError) > 0 && isPlatformError[0] {
|
||||||
|
pe := new(platform.Error)
|
||||||
|
parseErr := json.NewDecoder(resp.Body).Decode(pe)
|
||||||
|
if parseErr != nil {
|
||||||
|
return parseErr
|
||||||
|
}
|
||||||
|
err = pe
|
||||||
|
return err
|
||||||
|
}
|
||||||
// Attempt to read the X-Influx-Error header with the message.
|
// Attempt to read the X-Influx-Error header with the message.
|
||||||
if errMsg := resp.Header.Get(ErrorHeader); errMsg != "" {
|
if errMsg := resp.Header.Get(ErrorHeader); errMsg != "" {
|
||||||
// Parse the reference number as an integer. If we cannot parse it,
|
// Parse the reference number as an integer. If we cannot parse it,
|
||||||
|
|
|
@ -148,8 +148,9 @@ func (s *SetupService) Generate(ctx context.Context, or *platform.OnboardingRequ
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
// TODO(jsternberg): Should this check for a 201 explicitly?
|
// TODO(jsternberg): Should this check for a 201 explicitly?
|
||||||
if err := CheckError(resp); err != nil {
|
if err := CheckError(resp, true); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,13 +159,14 @@ func (s *SetupService) Generate(ctx context.Context, or *platform.OnboardingRequ
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bkt, err := oResp.Bucket.toPlatform()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &platform.OnboardingResults{
|
return &platform.OnboardingResults{
|
||||||
User: &oResp.User.User,
|
User: &oResp.User.User,
|
||||||
Auth: &oResp.Auth.Authorization,
|
Auth: &oResp.Auth.Authorization,
|
||||||
Org: &oResp.Organization.Organization,
|
Org: &oResp.Organization.Organization,
|
||||||
Bucket: &platform.Bucket{
|
Bucket: bkt,
|
||||||
ID: oResp.Bucket.ID,
|
|
||||||
Name: oResp.Bucket.Name,
|
|
||||||
},
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/influxdata/platform"
|
||||||
|
"github.com/influxdata/platform/inmem"
|
||||||
|
platformtesting "github.com/influxdata/platform/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initOnboardingService(f platformtesting.OnboardingFields, t *testing.T) (platform.OnboardingService, func()) {
|
||||||
|
t.Helper()
|
||||||
|
svc := inmem.NewService()
|
||||||
|
svc.IDGenerator = f.IDGenerator
|
||||||
|
svc.TokenGenerator = f.TokenGenerator
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := svc.PutOnboardingStatus(ctx, !f.IsOnboarding); err != nil {
|
||||||
|
t.Fatalf("failed to set new onboarding finished: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewSetupHandler()
|
||||||
|
handler.OnboardingService = svc
|
||||||
|
server := httptest.NewServer(handler)
|
||||||
|
client := struct {
|
||||||
|
*SetupService
|
||||||
|
*Service
|
||||||
|
platform.BasicAuthService
|
||||||
|
}{
|
||||||
|
SetupService: &SetupService{
|
||||||
|
Addr: server.URL,
|
||||||
|
},
|
||||||
|
Service: &Service{
|
||||||
|
Addr: server.URL,
|
||||||
|
},
|
||||||
|
BasicAuthService: svc,
|
||||||
|
}
|
||||||
|
|
||||||
|
done := server.Close
|
||||||
|
|
||||||
|
return client, done
|
||||||
|
}
|
||||||
|
func TestOnboardingService(t *testing.T) {
|
||||||
|
platformtesting.Generate(initOnboardingService, t)
|
||||||
|
}
|
|
@ -28,6 +28,7 @@ func initUserService(f platformtesting.UserFields, t *testing.T) (platform.UserS
|
||||||
client := UserService{
|
client := UserService{
|
||||||
Addr: server.URL,
|
Addr: server.URL,
|
||||||
}
|
}
|
||||||
|
|
||||||
done := server.Close
|
done := server.Close
|
||||||
|
|
||||||
return &client, done
|
return &client, done
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
package inmem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/influxdata/platform"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HashCost is currently using bcrypt defaultCost
|
||||||
|
const HashCost = bcrypt.DefaultCost
|
||||||
|
|
||||||
|
// SetPassword stores the password hash associated with a user.
|
||||||
|
func (s *Service) SetPassword(ctx context.Context, name string, password string) error {
|
||||||
|
u, err := s.FindUser(ctx, platform.UserFilter{Name: &name})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), HashCost)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.basicAuthKV.Store(u.ID.String(), hash)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ComparePassword compares a provided password with the stored password hash.
|
||||||
|
func (s *Service) ComparePassword(ctx context.Context, name string, password string) error {
|
||||||
|
u, err := s.FindUser(ctx, platform.UserFilter{Name: &name})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
hash, ok := s.basicAuthKV.Load(u.ID.String())
|
||||||
|
if !ok {
|
||||||
|
hash = []byte{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bcrypt.CompareHashAndPassword(hash.([]byte), []byte(password))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompareAndSetPassword replaces the old password with the new password if thee old password is correct.
|
||||||
|
func (s *Service) CompareAndSetPassword(ctx context.Context, name string, old string, new string) error {
|
||||||
|
if err := s.ComparePassword(ctx, name, old); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.SetPassword(ctx, name, new)
|
||||||
|
}
|
|
@ -0,0 +1,31 @@
|
||||||
|
package inmem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/influxdata/platform"
|
||||||
|
platformtesting "github.com/influxdata/platform/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initBasicAuthService(f platformtesting.UserFields, t *testing.T) (platform.BasicAuthService, func()) {
|
||||||
|
s := NewService()
|
||||||
|
s.IDGenerator = f.IDGenerator
|
||||||
|
ctx := context.Background()
|
||||||
|
for _, u := range f.Users {
|
||||||
|
if err := s.PutUser(ctx, u); err != nil {
|
||||||
|
t.Fatalf("failed to populate users")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s, func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasicAuth(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
platformtesting.BasicAuth(initBasicAuthService, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasicAuth_CompareAndSet(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
platformtesting.CompareAndSetPassword(initBasicAuthService, t)
|
||||||
|
}
|
|
@ -0,0 +1,121 @@
|
||||||
|
package inmem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/influxdata/platform"
|
||||||
|
)
|
||||||
|
|
||||||
|
const onboardingKey = "onboarding_key"
|
||||||
|
|
||||||
|
var _ platform.OnboardingService = (*Service)(nil)
|
||||||
|
|
||||||
|
// IsOnboarding checks onboardingBucket
|
||||||
|
// to see if the onboarding key is true.
|
||||||
|
func (s *Service) IsOnboarding(ctx context.Context) (isOnboarding bool, err error) {
|
||||||
|
result, ok := s.onboardingKV.Load(onboardingKey)
|
||||||
|
isOnboarding = !ok || !result.(bool)
|
||||||
|
return isOnboarding, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutOnboardingStatus will put the isOnboarding to storage
|
||||||
|
func (s *Service) PutOnboardingStatus(ctx context.Context, v bool) error {
|
||||||
|
s.onboardingKV.Store(onboardingKey, v)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate OnboardingResults from onboarding request,
|
||||||
|
// update storage so this request will be disabled for the second run.
|
||||||
|
func (s *Service) Generate(ctx context.Context, req *platform.OnboardingRequest) (*platform.OnboardingResults, error) {
|
||||||
|
isOnboarding, err := s.IsOnboarding(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !isOnboarding {
|
||||||
|
return nil, &platform.Error{
|
||||||
|
Code: platform.EConflict,
|
||||||
|
Msg: "onboarding has already been completed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Password == "" {
|
||||||
|
return nil, &platform.Error{
|
||||||
|
Code: platform.EEmptyValue,
|
||||||
|
Msg: "password is empty",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.User == "" {
|
||||||
|
return nil, &platform.Error{
|
||||||
|
Code: platform.EEmptyValue,
|
||||||
|
Msg: "username is empty",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Org == "" {
|
||||||
|
return nil, &platform.Error{
|
||||||
|
Code: platform.EEmptyValue,
|
||||||
|
Msg: "org name is empty",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Bucket == "" {
|
||||||
|
return nil, &platform.Error{
|
||||||
|
Code: platform.EEmptyValue,
|
||||||
|
Msg: "bucket name is empty",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
u := &platform.User{Name: req.User}
|
||||||
|
if err := s.CreateUser(ctx, u); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = s.SetPassword(ctx, u.Name, req.Password); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
o := &platform.Organization{
|
||||||
|
Name: req.Org,
|
||||||
|
}
|
||||||
|
if err = s.CreateOrganization(ctx, o); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
bucket := &platform.Bucket{
|
||||||
|
Name: req.Bucket,
|
||||||
|
Organization: o.Name,
|
||||||
|
OrganizationID: o.ID,
|
||||||
|
RetentionPeriod: time.Duration(req.RetentionPeriod) * time.Hour,
|
||||||
|
}
|
||||||
|
if err = s.CreateBucket(ctx, bucket); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
auth := &platform.Authorization{
|
||||||
|
User: u.Name,
|
||||||
|
UserID: u.ID,
|
||||||
|
Permissions: []platform.Permission{
|
||||||
|
platform.CreateUserPermission,
|
||||||
|
platform.DeleteUserPermission,
|
||||||
|
platform.Permission{
|
||||||
|
Resource: platform.OrganizationResource,
|
||||||
|
Action: platform.WriteAction,
|
||||||
|
},
|
||||||
|
platform.WriteBucketPermission(bucket.ID),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err = s.CreateAuthorization(ctx, auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = s.PutOnboardingStatus(ctx, true); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &platform.OnboardingResults{
|
||||||
|
User: u,
|
||||||
|
Org: o,
|
||||||
|
Bucket: bucket,
|
||||||
|
Auth: auth,
|
||||||
|
}, nil
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
package inmem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/influxdata/platform"
|
||||||
|
platformtesting "github.com/influxdata/platform/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initOnboardingService(f platformtesting.OnboardingFields, t *testing.T) (platform.OnboardingService, func()) {
|
||||||
|
s := NewService()
|
||||||
|
s.IDGenerator = f.IDGenerator
|
||||||
|
s.TokenGenerator = f.TokenGenerator
|
||||||
|
ctx := context.TODO()
|
||||||
|
if err := s.PutOnboardingStatus(ctx, !f.IsOnboarding); err != nil {
|
||||||
|
t.Fatalf("failed to set new onboarding finished: %v", err)
|
||||||
|
}
|
||||||
|
return s, func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate(t *testing.T) {
|
||||||
|
platformtesting.Generate(initOnboardingService, t)
|
||||||
|
}
|
|
@ -22,6 +22,8 @@ type Service struct {
|
||||||
userResourceMappingKV sync.Map
|
userResourceMappingKV sync.Map
|
||||||
scraperTargetKV sync.Map
|
scraperTargetKV sync.Map
|
||||||
telegrafConfigKV sync.Map
|
telegrafConfigKV sync.Map
|
||||||
|
onboardingKV sync.Map
|
||||||
|
basicAuthKV sync.Map
|
||||||
|
|
||||||
TokenGenerator platform.TokenGenerator
|
TokenGenerator platform.TokenGenerator
|
||||||
IDGenerator platform.IDGenerator
|
IDGenerator platform.IDGenerator
|
||||||
|
|
|
@ -22,6 +22,12 @@ type OnboardingRequest struct {
|
||||||
|
|
||||||
// OnboardingService represents a service for the first run.
|
// OnboardingService represents a service for the first run.
|
||||||
type OnboardingService interface {
|
type OnboardingService interface {
|
||||||
|
BasicAuthService
|
||||||
|
BucketService
|
||||||
|
OrganizationService
|
||||||
|
UserService
|
||||||
|
AuthorizationService
|
||||||
|
|
||||||
// IsOnboarding determine if onboarding request is allowed.
|
// IsOnboarding determine if onboarding request is allowed.
|
||||||
IsOnboarding(ctx context.Context) (bool, error)
|
IsOnboarding(ctx context.Context) (bool, error)
|
||||||
// Generate OnboardingResults.
|
// Generate OnboardingResults.
|
||||||
|
|
|
@ -0,0 +1,171 @@
|
||||||
|
package testing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/influxdata/platform"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BasicAuth test all the services for basic auth
|
||||||
|
func BasicAuth(
|
||||||
|
init func(UserFields, *testing.T) (platform.BasicAuthService, func()),
|
||||||
|
t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
name string
|
||||||
|
user string
|
||||||
|
setPassword string
|
||||||
|
comparePassword string
|
||||||
|
}
|
||||||
|
type wants struct {
|
||||||
|
setErr error
|
||||||
|
compareErr error
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
fields UserFields
|
||||||
|
args args
|
||||||
|
wants wants
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
fields: UserFields{
|
||||||
|
Users: []*platform.User{
|
||||||
|
{
|
||||||
|
Name: "user1",
|
||||||
|
ID: MustIDBase16(oneID),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
name: "happy path",
|
||||||
|
user: "user1",
|
||||||
|
setPassword: "hello",
|
||||||
|
comparePassword: "hello",
|
||||||
|
},
|
||||||
|
wants: wants{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
fields: UserFields{
|
||||||
|
Users: []*platform.User{
|
||||||
|
{
|
||||||
|
Name: "user1",
|
||||||
|
ID: MustIDBase16(oneID),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
name: "happy path dont match",
|
||||||
|
user: "user1",
|
||||||
|
setPassword: "hello",
|
||||||
|
comparePassword: "world",
|
||||||
|
},
|
||||||
|
wants: wants{
|
||||||
|
compareErr: fmt.Errorf("crypto/bcrypt: hashedPassword is not the hash of the given password"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.args.name, func(t *testing.T) {
|
||||||
|
s, done := init(tt.fields, t)
|
||||||
|
defer done()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err := s.SetPassword(ctx, tt.args.user, tt.args.setPassword)
|
||||||
|
|
||||||
|
if (err != nil && tt.wants.setErr == nil) || (err == nil && tt.wants.setErr != nil) {
|
||||||
|
t.Fatalf("expected SetPassword error %v got %v", tt.wants.setErr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if want, got := tt.wants.setErr.Error(), err.Error(); want != got {
|
||||||
|
t.Fatalf("expected SetPassword error %v got %v", want, got)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.ComparePassword(ctx, tt.args.user, tt.args.comparePassword)
|
||||||
|
|
||||||
|
if (err != nil && tt.wants.compareErr == nil) || (err == nil && tt.wants.compareErr != nil) {
|
||||||
|
t.Fatalf("expected ComparePassword error %v got %v", tt.wants.compareErr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if want, got := tt.wants.compareErr.Error(), err.Error(); want != got {
|
||||||
|
t.Fatalf("expected ComparePassword error %v got %v", tt.wants.compareErr, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompareAndSetPassword test
|
||||||
|
func CompareAndSetPassword(
|
||||||
|
init func(UserFields, *testing.T) (platform.BasicAuthService, func()),
|
||||||
|
t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
name string
|
||||||
|
user string
|
||||||
|
old string
|
||||||
|
new string
|
||||||
|
}
|
||||||
|
type wants struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
fields UserFields
|
||||||
|
args args
|
||||||
|
wants wants
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
fields: UserFields{
|
||||||
|
Users: []*platform.User{
|
||||||
|
{
|
||||||
|
Name: "user1",
|
||||||
|
ID: MustIDBase16(oneID),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
name: "happy path",
|
||||||
|
user: "user1",
|
||||||
|
old: "hello",
|
||||||
|
new: "hello",
|
||||||
|
},
|
||||||
|
wants: wants{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.args.name, func(t *testing.T) {
|
||||||
|
s, done := init(tt.fields, t)
|
||||||
|
defer done()
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := s.SetPassword(ctx, tt.args.user, tt.args.old); err != nil {
|
||||||
|
t.Fatalf("unexpected error %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.CompareAndSetPassword(ctx, tt.args.user, tt.args.old, tt.args.new)
|
||||||
|
|
||||||
|
if (err != nil && tt.wants.err == nil) || (err == nil && tt.wants.err != nil) {
|
||||||
|
t.Fatalf("expected CompareAndSetPassword error %v got %v", tt.wants.err, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if want, got := tt.wants.err.Error(), err.Error(); want != got {
|
||||||
|
t.Fatalf("expected CompareAndSetPassword error %v got %v", tt.wants.err, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -18,16 +18,9 @@ type OnboardingFields struct {
|
||||||
IsOnboarding bool
|
IsOnboarding bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnBoardingNBasicAuthService includes onboarding service
|
|
||||||
// and basic auth service.
|
|
||||||
type OnBoardingNBasicAuthService interface {
|
|
||||||
platform.OnboardingService
|
|
||||||
platform.BasicAuthService
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate testing
|
// Generate testing
|
||||||
func Generate(
|
func Generate(
|
||||||
init func(OnboardingFields, *testing.T) (OnBoardingNBasicAuthService, func()),
|
init func(OnboardingFields, *testing.T) (platform.OnboardingService, func()),
|
||||||
t *testing.T,
|
t *testing.T,
|
||||||
) {
|
) {
|
||||||
type args struct {
|
type args struct {
|
||||||
|
|
1
user.go
1
user.go
|
@ -10,6 +10,7 @@ type User struct {
|
||||||
|
|
||||||
// UserService represents a service for managing user data.
|
// UserService represents a service for managing user data.
|
||||||
type UserService interface {
|
type UserService interface {
|
||||||
|
|
||||||
// Returns a single user by ID.
|
// Returns a single user by ID.
|
||||||
FindUserByID(ctx context.Context, id ID) (*User, error)
|
FindUserByID(ctx context.Context, id ID) (*User, error)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue