(feat/testing) add onboarding and basic auth service

pull/10616/head
Kelvin Wang 2018-10-03 10:51:14 -04:00
parent 3e54ef9f53
commit 63da5d1e9f
19 changed files with 589 additions and 264 deletions

View File

@ -0,0 +1,69 @@
package bolt
import (
"context"
bolt "github.com/coreos/bbolt"
"golang.org/x/crypto/bcrypt"
)
// SetPassword stores the password hash associated with a user.
func (c *Client) SetPassword(ctx context.Context, name string, password string) error {
return c.db.Update(func(tx *bolt.Tx) error {
return c.setPassword(ctx, tx, name, password)
})
}
// HashCost currently using the default cost of bcrypt
var HashCost = bcrypt.DefaultCost
func (c *Client) setPassword(ctx context.Context, tx *bolt.Tx, name string, password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), HashCost)
if err != nil {
return err
}
u, err := c.findUserByName(ctx, tx, name)
if err != nil {
return err
}
encodedID, err := u.ID.Encode()
if err != nil {
return err
}
return tx.Bucket(userpasswordBucket).Put(encodedID, hash)
}
// ComparePassword compares a provided password with the stored password hash.
func (c *Client) ComparePassword(ctx context.Context, name string, password string) error {
return c.db.View(func(tx *bolt.Tx) error {
return c.comparePassword(ctx, tx, name, password)
})
}
func (c *Client) comparePassword(ctx context.Context, tx *bolt.Tx, name string, password string) error {
u, err := c.findUserByName(ctx, tx, name)
if err != nil {
return err
}
encodedID, err := u.ID.Encode()
if err != nil {
return err
}
hash := tx.Bucket(userpasswordBucket).Get(encodedID)
return bcrypt.CompareHashAndPassword(hash, []byte(password))
}
// CompareAndSetPassword replaces the old password with the new password if thee old password is correct.
func (c *Client) CompareAndSetPassword(ctx context.Context, name string, old string, new string) error {
return c.db.Update(func(tx *bolt.Tx) error {
if err := c.comparePassword(ctx, tx, name, old); err != nil {
return err
}
return c.setPassword(ctx, tx, name, new)
})
}

View File

@ -0,0 +1,41 @@
package bolt_test
import (
"context"
"testing"
"github.com/influxdata/platform"
platformtesting "github.com/influxdata/platform/testing"
)
func initBasicAuthService(f platformtesting.UserFields, t *testing.T) (platform.BasicAuthService, func()) {
c, closeFn, err := NewTestClient()
if err != nil {
t.Fatalf("failed to create new bolt client: %v", err)
}
c.IDGenerator = f.IDGenerator
ctx := context.Background()
for _, u := range f.Users {
if err := c.PutUser(ctx, u); err != nil {
t.Fatalf("failed to populate users")
}
}
return c, func() {
defer closeFn()
for _, u := range f.Users {
if err := c.DeleteUser(ctx, u.ID); err != nil {
t.Logf("failed to remove users: %v", err)
}
}
}
}
func TestBasicAuth(t *testing.T) {
t.Parallel()
platformtesting.BasicAuth(initBasicAuthService, t)
}
func TestBasicAuth_CompareAndSet(t *testing.T) {
t.Parallel()
platformtesting.CompareAndSetPassword(initBasicAuthService, t)
}

View File

@ -4,10 +4,11 @@ import (
"context" "context"
"testing" "testing"
"github.com/influxdata/platform"
platformtesting "github.com/influxdata/platform/testing" platformtesting "github.com/influxdata/platform/testing"
) )
func initOnboardingService(f platformtesting.OnboardingFields, t *testing.T) (platformtesting.OnBoardingNBasicAuthService, func()) { func initOnboardingService(f platformtesting.OnboardingFields, t *testing.T) (platform.OnboardingService, func()) {
c, closeFn, err := NewTestClient() c, closeFn, err := NewTestClient()
if err != nil { if err != nil {
t.Fatalf("failed to create new bolt client: %v", err) t.Fatalf("failed to create new bolt client: %v", err)

View File

@ -7,7 +7,6 @@ import (
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/influxdata/platform" "github.com/influxdata/platform"
"golang.org/x/crypto/bcrypt"
) )
var ( var (
@ -341,63 +340,3 @@ func (c *Client) deleteUsersAuthorizations(ctx context.Context, tx *bolt.Tx, id
} }
return nil return nil
} }
// SetPassword stores the password hash associated with a user.
func (c *Client) SetPassword(ctx context.Context, name string, password string) error {
return c.db.Update(func(tx *bolt.Tx) error {
return c.setPassword(ctx, tx, name, password)
})
}
var HashCost = bcrypt.DefaultCost
func (c *Client) setPassword(ctx context.Context, tx *bolt.Tx, name string, password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), HashCost)
if err != nil {
return err
}
u, err := c.findUserByName(ctx, tx, name)
if err != nil {
return err
}
encodedID, err := u.ID.Encode()
if err != nil {
return err
}
return tx.Bucket(userpasswordBucket).Put(encodedID, hash)
}
// ComparePassword compares a provided password with the stored password hash.
func (c *Client) ComparePassword(ctx context.Context, name string, password string) error {
return c.db.View(func(tx *bolt.Tx) error {
return c.comparePassword(ctx, tx, name, password)
})
}
func (c *Client) comparePassword(ctx context.Context, tx *bolt.Tx, name string, password string) error {
u, err := c.findUserByName(ctx, tx, name)
if err != nil {
return err
}
encodedID, err := u.ID.Encode()
if err != nil {
return err
}
hash := tx.Bucket(userpasswordBucket).Get(encodedID)
return bcrypt.CompareHashAndPassword(hash, []byte(password))
}
// CompareAndSetPassword replaces the old password with the new password if thee old password is correct.
func (c *Client) CompareAndSetPassword(ctx context.Context, name string, old string, new string) error {
return c.db.Update(func(tx *bolt.Tx) error {
if err := c.comparePassword(ctx, tx, name, old); err != nil {
return err
}
return c.setPassword(ctx, tx, name, new)
})
}

View File

@ -29,6 +29,7 @@ func filterMappingsFn(filter platform.UserResourceMappingFilter) func(m *platfor
} }
} }
// FindUserResourceMappings returns a list of UserResourceMappings that match filter and the total count of matching mappings.
func (c *Client) FindUserResourceMappings(ctx context.Context, filter platform.UserResourceMappingFilter, opt ...platform.FindOptions) ([]*platform.UserResourceMapping, int, error) { func (c *Client) FindUserResourceMappings(ctx context.Context, filter platform.UserResourceMappingFilter, opt ...platform.FindOptions) ([]*platform.UserResourceMapping, int, error) {
ms := []*platform.UserResourceMapping{} ms := []*platform.UserResourceMapping{}
err := c.db.View(func(tx *bolt.Tx) error { err := c.db.View(func(tx *bolt.Tx) error {
@ -137,6 +138,7 @@ func (c *Client) uniqueUserResourceMapping(ctx context.Context, tx *bolt.Tx, m *
return len(v) == 0 return len(v) == 0
} }
// DeleteUserResourceMapping deletes a user resource mapping.
func (c *Client) DeleteUserResourceMapping(ctx context.Context, resourceID platform.ID, userID platform.ID) error { func (c *Client) DeleteUserResourceMapping(ctx context.Context, resourceID platform.ID, userID platform.ID) error {
return c.db.Update(func(tx *bolt.Tx) error { return c.db.Update(func(tx *bolt.Tx) error {
return c.deleteUserResourceMapping(ctx, tx, platform.UserResourceMappingFilter{ return c.deleteUserResourceMapping(ctx, tx, platform.UserResourceMappingFilter{

View File

@ -2,7 +2,6 @@ package bolt_test
import ( import (
"context" "context"
"fmt"
"testing" "testing"
"github.com/influxdata/platform" "github.com/influxdata/platform"
@ -54,186 +53,3 @@ func TestUserService_FindUser(t *testing.T) {
func TestUserService_UpdateUser(t *testing.T) { func TestUserService_UpdateUser(t *testing.T) {
platformtesting.UpdateUser(initUserService, t) platformtesting.UpdateUser(initUserService, t)
} }
func TestBasicAuth(t *testing.T) {
type fields struct {
users []*platform.User
}
type args struct {
name string
user string
setPassword string
comparePassword string
}
type wants struct {
setErr error
compareErr error
}
tests := []struct {
fields fields
args args
wants wants
}{
{
fields: fields{
users: []*platform.User{
{
Name: "user1",
ID: platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa"),
},
},
},
args: args{
name: "happy path",
user: "user1",
setPassword: "hello",
comparePassword: "hello",
},
wants: wants{},
},
{
fields: fields{
users: []*platform.User{
{
Name: "user1",
ID: platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa"),
},
},
},
args: args{
name: "happy path dont match",
user: "user1",
setPassword: "hello",
comparePassword: "world",
},
wants: wants{
compareErr: fmt.Errorf("crypto/bcrypt: hashedPassword is not the hash of the given password"),
},
},
}
for _, tt := range tests {
t.Run(tt.args.name, func(t *testing.T) {
c, closeFn, err := NewTestClient()
if err != nil {
t.Fatalf("failed to create new bolt client: %v", err)
}
defer closeFn()
ctx := context.Background()
for _, user := range tt.fields.users {
if err := c.PutUser(ctx, user); err != nil {
t.Fatal(err)
return
}
}
err = c.SetPassword(ctx, tt.args.user, tt.args.setPassword)
if (err != nil && tt.wants.setErr == nil) || (err == nil && tt.wants.setErr != nil) {
t.Fatalf("expected SetPassword error %v got %v", tt.wants.setErr, err)
return
}
if err != nil {
if want, got := tt.wants.setErr.Error(), err.Error(); want != got {
t.Fatalf("expected SetPassword error %v got %v", want, got)
}
return
}
err = c.ComparePassword(ctx, tt.args.user, tt.args.comparePassword)
if (err != nil && tt.wants.compareErr == nil) || (err == nil && tt.wants.compareErr != nil) {
t.Fatalf("expected ComparePassword error %v got %v", tt.wants.compareErr, err)
return
}
if err != nil {
if want, got := tt.wants.compareErr.Error(), err.Error(); want != got {
t.Fatalf("expected ComparePassword error %v got %v", tt.wants.compareErr, err)
}
return
}
})
}
}
func TestBasicAuth_CompareAndSet(t *testing.T) {
type fields struct {
users []*platform.User
}
type args struct {
name string
user string
old string
new string
}
type wants struct {
err error
}
tests := []struct {
fields fields
args args
wants wants
}{
{
fields: fields{
users: []*platform.User{
{
Name: "user1",
ID: platformtesting.MustIDBase16("aaaaaaaaaaaaaaaa"),
},
},
},
args: args{
name: "happy path",
user: "user1",
old: "hello",
new: "hello",
},
wants: wants{},
},
}
for _, tt := range tests {
t.Run(tt.args.name, func(t *testing.T) {
c, closeFn, err := NewTestClient()
if err != nil {
t.Fatalf("failed to create new bolt client: %v", err)
}
defer closeFn()
ctx := context.Background()
for _, user := range tt.fields.users {
if err := c.PutUser(ctx, user); err != nil {
t.Fatal(err)
return
}
}
if err := c.SetPassword(ctx, tt.args.user, tt.args.old); err != nil {
t.Fatalf("unexpected error %v", err)
return
}
err = c.CompareAndSetPassword(ctx, tt.args.user, tt.args.old, tt.args.new)
if (err != nil && tt.wants.err == nil) || (err == nil && tt.wants.err != nil) {
t.Fatalf("expected CompareAndSetPassword error %v got %v", tt.wants.err, err)
return
}
if err != nil {
if want, got := tt.wants.err.Error(), err.Error(); want != got {
t.Fatalf("expected CompareAndSetPassword error %v got %v", tt.wants.err, err)
}
return
}
})
}
}

View File

@ -52,7 +52,8 @@ func CheckErrorStatus(code int, res *http.Response) error {
// be determined in that way, it will create a generic error message. // be determined in that way, it will create a generic error message.
// //
// If there is no error, then this returns nil. // If there is no error, then this returns nil.
func CheckError(resp *http.Response) error { // Add an optional isPlatformError, to do decode with platform.Error
func CheckError(resp *http.Response, isPlatformError ...bool) (err error) {
switch resp.StatusCode / 100 { switch resp.StatusCode / 100 {
case 4, 5: case 4, 5:
// We will attempt to parse this error outside of this block. // We will attempt to parse this error outside of this block.
@ -62,7 +63,15 @@ func CheckError(resp *http.Response) error {
// TODO(jsternberg): Figure out what to do here? // TODO(jsternberg): Figure out what to do here?
return kerrors.InternalErrorf("unexpected status code: %d %s", resp.StatusCode, resp.Status) return kerrors.InternalErrorf("unexpected status code: %d %s", resp.StatusCode, resp.Status)
} }
if len(isPlatformError) > 0 && isPlatformError[0] {
pe := new(platform.Error)
parseErr := json.NewDecoder(resp.Body).Decode(pe)
if parseErr != nil {
return parseErr
}
err = pe
return err
}
// Attempt to read the X-Influx-Error header with the message. // Attempt to read the X-Influx-Error header with the message.
if errMsg := resp.Header.Get(ErrorHeader); errMsg != "" { if errMsg := resp.Header.Get(ErrorHeader); errMsg != "" {
// Parse the reference number as an integer. If we cannot parse it, // Parse the reference number as an integer. If we cannot parse it,

View File

@ -148,8 +148,9 @@ func (s *SetupService) Generate(ctx context.Context, or *platform.OnboardingRequ
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close()
// TODO(jsternberg): Should this check for a 201 explicitly? // TODO(jsternberg): Should this check for a 201 explicitly?
if err := CheckError(resp); err != nil { if err := CheckError(resp, true); err != nil {
return nil, err return nil, err
} }
@ -158,13 +159,14 @@ func (s *SetupService) Generate(ctx context.Context, or *platform.OnboardingRequ
return nil, err return nil, err
} }
bkt, err := oResp.Bucket.toPlatform()
if err != nil {
return nil, err
}
return &platform.OnboardingResults{ return &platform.OnboardingResults{
User: &oResp.User.User, User: &oResp.User.User,
Auth: &oResp.Auth.Authorization, Auth: &oResp.Auth.Authorization,
Org: &oResp.Organization.Organization, Org: &oResp.Organization.Organization,
Bucket: &platform.Bucket{ Bucket: bkt,
ID: oResp.Bucket.ID,
Name: oResp.Bucket.Name,
},
}, nil }, nil
} }

47
http/onboarding_test.go Normal file
View File

@ -0,0 +1,47 @@
package http
import (
"context"
"net/http/httptest"
"testing"
"github.com/influxdata/platform"
"github.com/influxdata/platform/inmem"
platformtesting "github.com/influxdata/platform/testing"
)
func initOnboardingService(f platformtesting.OnboardingFields, t *testing.T) (platform.OnboardingService, func()) {
t.Helper()
svc := inmem.NewService()
svc.IDGenerator = f.IDGenerator
svc.TokenGenerator = f.TokenGenerator
ctx := context.Background()
if err := svc.PutOnboardingStatus(ctx, !f.IsOnboarding); err != nil {
t.Fatalf("failed to set new onboarding finished: %v", err)
}
handler := NewSetupHandler()
handler.OnboardingService = svc
server := httptest.NewServer(handler)
client := struct {
*SetupService
*Service
platform.BasicAuthService
}{
SetupService: &SetupService{
Addr: server.URL,
},
Service: &Service{
Addr: server.URL,
},
BasicAuthService: svc,
}
done := server.Close
return client, done
}
func TestOnboardingService(t *testing.T) {
platformtesting.Generate(initOnboardingService, t)
}

View File

@ -28,6 +28,7 @@ func initUserService(f platformtesting.UserFields, t *testing.T) (platform.UserS
client := UserService{ client := UserService{
Addr: server.URL, Addr: server.URL,
} }
done := server.Close done := server.Close
return &client, done return &client, done

49
inmem/basic_auth.go Normal file
View File

@ -0,0 +1,49 @@
package inmem
import (
"context"
"github.com/influxdata/platform"
"golang.org/x/crypto/bcrypt"
)
// HashCost is currently using bcrypt defaultCost
const HashCost = bcrypt.DefaultCost
// SetPassword stores the password hash associated with a user.
func (s *Service) SetPassword(ctx context.Context, name string, password string) error {
u, err := s.FindUser(ctx, platform.UserFilter{Name: &name})
if err != nil {
return err
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), HashCost)
if err != nil {
return err
}
s.basicAuthKV.Store(u.ID.String(), hash)
return nil
}
// ComparePassword compares a provided password with the stored password hash.
func (s *Service) ComparePassword(ctx context.Context, name string, password string) error {
u, err := s.FindUser(ctx, platform.UserFilter{Name: &name})
if err != nil {
return err
}
hash, ok := s.basicAuthKV.Load(u.ID.String())
if !ok {
hash = []byte{}
}
return bcrypt.CompareHashAndPassword(hash.([]byte), []byte(password))
}
// CompareAndSetPassword replaces the old password with the new password if thee old password is correct.
func (s *Service) CompareAndSetPassword(ctx context.Context, name string, old string, new string) error {
if err := s.ComparePassword(ctx, name, old); err != nil {
return err
}
return s.SetPassword(ctx, name, new)
}

31
inmem/basic_auth_test.go Normal file
View File

@ -0,0 +1,31 @@
package inmem
import (
"context"
"testing"
"github.com/influxdata/platform"
platformtesting "github.com/influxdata/platform/testing"
)
func initBasicAuthService(f platformtesting.UserFields, t *testing.T) (platform.BasicAuthService, func()) {
s := NewService()
s.IDGenerator = f.IDGenerator
ctx := context.Background()
for _, u := range f.Users {
if err := s.PutUser(ctx, u); err != nil {
t.Fatalf("failed to populate users")
}
}
return s, func() {}
}
func TestBasicAuth(t *testing.T) {
t.Parallel()
platformtesting.BasicAuth(initBasicAuthService, t)
}
func TestBasicAuth_CompareAndSet(t *testing.T) {
t.Parallel()
platformtesting.CompareAndSetPassword(initBasicAuthService, t)
}

121
inmem/onboarding.go Normal file
View File

@ -0,0 +1,121 @@
package inmem
import (
"context"
"time"
"github.com/influxdata/platform"
)
const onboardingKey = "onboarding_key"
var _ platform.OnboardingService = (*Service)(nil)
// IsOnboarding checks onboardingBucket
// to see if the onboarding key is true.
func (s *Service) IsOnboarding(ctx context.Context) (isOnboarding bool, err error) {
result, ok := s.onboardingKV.Load(onboardingKey)
isOnboarding = !ok || !result.(bool)
return isOnboarding, nil
}
// PutOnboardingStatus will put the isOnboarding to storage
func (s *Service) PutOnboardingStatus(ctx context.Context, v bool) error {
s.onboardingKV.Store(onboardingKey, v)
return nil
}
// Generate OnboardingResults from onboarding request,
// update storage so this request will be disabled for the second run.
func (s *Service) Generate(ctx context.Context, req *platform.OnboardingRequest) (*platform.OnboardingResults, error) {
isOnboarding, err := s.IsOnboarding(ctx)
if err != nil {
return nil, err
}
if !isOnboarding {
return nil, &platform.Error{
Code: platform.EConflict,
Msg: "onboarding has already been completed",
}
}
if req.Password == "" {
return nil, &platform.Error{
Code: platform.EEmptyValue,
Msg: "password is empty",
}
}
if req.User == "" {
return nil, &platform.Error{
Code: platform.EEmptyValue,
Msg: "username is empty",
}
}
if req.Org == "" {
return nil, &platform.Error{
Code: platform.EEmptyValue,
Msg: "org name is empty",
}
}
if req.Bucket == "" {
return nil, &platform.Error{
Code: platform.EEmptyValue,
Msg: "bucket name is empty",
}
}
u := &platform.User{Name: req.User}
if err := s.CreateUser(ctx, u); err != nil {
return nil, err
}
if err = s.SetPassword(ctx, u.Name, req.Password); err != nil {
return nil, err
}
o := &platform.Organization{
Name: req.Org,
}
if err = s.CreateOrganization(ctx, o); err != nil {
return nil, err
}
bucket := &platform.Bucket{
Name: req.Bucket,
Organization: o.Name,
OrganizationID: o.ID,
RetentionPeriod: time.Duration(req.RetentionPeriod) * time.Hour,
}
if err = s.CreateBucket(ctx, bucket); err != nil {
return nil, err
}
auth := &platform.Authorization{
User: u.Name,
UserID: u.ID,
Permissions: []platform.Permission{
platform.CreateUserPermission,
platform.DeleteUserPermission,
platform.Permission{
Resource: platform.OrganizationResource,
Action: platform.WriteAction,
},
platform.WriteBucketPermission(bucket.ID),
},
}
if err = s.CreateAuthorization(ctx, auth); err != nil {
return nil, err
}
if err = s.PutOnboardingStatus(ctx, true); err != nil {
return nil, err
}
return &platform.OnboardingResults{
User: u,
Org: o,
Bucket: bucket,
Auth: auth,
}, nil
}

24
inmem/onboarding_test.go Normal file
View File

@ -0,0 +1,24 @@
package inmem
import (
"context"
"testing"
"github.com/influxdata/platform"
platformtesting "github.com/influxdata/platform/testing"
)
func initOnboardingService(f platformtesting.OnboardingFields, t *testing.T) (platform.OnboardingService, func()) {
s := NewService()
s.IDGenerator = f.IDGenerator
s.TokenGenerator = f.TokenGenerator
ctx := context.TODO()
if err := s.PutOnboardingStatus(ctx, !f.IsOnboarding); err != nil {
t.Fatalf("failed to set new onboarding finished: %v", err)
}
return s, func() {}
}
func TestGenerate(t *testing.T) {
platformtesting.Generate(initOnboardingService, t)
}

View File

@ -22,6 +22,8 @@ type Service struct {
userResourceMappingKV sync.Map userResourceMappingKV sync.Map
scraperTargetKV sync.Map scraperTargetKV sync.Map
telegrafConfigKV sync.Map telegrafConfigKV sync.Map
onboardingKV sync.Map
basicAuthKV sync.Map
TokenGenerator platform.TokenGenerator TokenGenerator platform.TokenGenerator
IDGenerator platform.IDGenerator IDGenerator platform.IDGenerator

View File

@ -22,6 +22,12 @@ type OnboardingRequest struct {
// OnboardingService represents a service for the first run. // OnboardingService represents a service for the first run.
type OnboardingService interface { type OnboardingService interface {
BasicAuthService
BucketService
OrganizationService
UserService
AuthorizationService
// IsOnboarding determine if onboarding request is allowed. // IsOnboarding determine if onboarding request is allowed.
IsOnboarding(ctx context.Context) (bool, error) IsOnboarding(ctx context.Context) (bool, error)
// Generate OnboardingResults. // Generate OnboardingResults.

171
testing/basic_auth.go Normal file
View File

@ -0,0 +1,171 @@
package testing
import (
"context"
"fmt"
"testing"
"github.com/influxdata/platform"
)
// BasicAuth test all the services for basic auth
func BasicAuth(
init func(UserFields, *testing.T) (platform.BasicAuthService, func()),
t *testing.T) {
type args struct {
name string
user string
setPassword string
comparePassword string
}
type wants struct {
setErr error
compareErr error
}
tests := []struct {
fields UserFields
args args
wants wants
}{
{
fields: UserFields{
Users: []*platform.User{
{
Name: "user1",
ID: MustIDBase16(oneID),
},
},
},
args: args{
name: "happy path",
user: "user1",
setPassword: "hello",
comparePassword: "hello",
},
wants: wants{},
},
{
fields: UserFields{
Users: []*platform.User{
{
Name: "user1",
ID: MustIDBase16(oneID),
},
},
},
args: args{
name: "happy path dont match",
user: "user1",
setPassword: "hello",
comparePassword: "world",
},
wants: wants{
compareErr: fmt.Errorf("crypto/bcrypt: hashedPassword is not the hash of the given password"),
},
},
}
for _, tt := range tests {
t.Run(tt.args.name, func(t *testing.T) {
s, done := init(tt.fields, t)
defer done()
ctx := context.Background()
err := s.SetPassword(ctx, tt.args.user, tt.args.setPassword)
if (err != nil && tt.wants.setErr == nil) || (err == nil && tt.wants.setErr != nil) {
t.Fatalf("expected SetPassword error %v got %v", tt.wants.setErr, err)
return
}
if err != nil {
if want, got := tt.wants.setErr.Error(), err.Error(); want != got {
t.Fatalf("expected SetPassword error %v got %v", want, got)
}
return
}
err = s.ComparePassword(ctx, tt.args.user, tt.args.comparePassword)
if (err != nil && tt.wants.compareErr == nil) || (err == nil && tt.wants.compareErr != nil) {
t.Fatalf("expected ComparePassword error %v got %v", tt.wants.compareErr, err)
return
}
if err != nil {
if want, got := tt.wants.compareErr.Error(), err.Error(); want != got {
t.Fatalf("expected ComparePassword error %v got %v", tt.wants.compareErr, err)
}
return
}
})
}
}
// CompareAndSetPassword test
func CompareAndSetPassword(
init func(UserFields, *testing.T) (platform.BasicAuthService, func()),
t *testing.T) {
type args struct {
name string
user string
old string
new string
}
type wants struct {
err error
}
tests := []struct {
fields UserFields
args args
wants wants
}{
{
fields: UserFields{
Users: []*platform.User{
{
Name: "user1",
ID: MustIDBase16(oneID),
},
},
},
args: args{
name: "happy path",
user: "user1",
old: "hello",
new: "hello",
},
wants: wants{},
},
}
for _, tt := range tests {
t.Run(tt.args.name, func(t *testing.T) {
s, done := init(tt.fields, t)
defer done()
ctx := context.Background()
if err := s.SetPassword(ctx, tt.args.user, tt.args.old); err != nil {
t.Fatalf("unexpected error %v", err)
return
}
err := s.CompareAndSetPassword(ctx, tt.args.user, tt.args.old, tt.args.new)
if (err != nil && tt.wants.err == nil) || (err == nil && tt.wants.err != nil) {
t.Fatalf("expected CompareAndSetPassword error %v got %v", tt.wants.err, err)
return
}
if err != nil {
if want, got := tt.wants.err.Error(), err.Error(); want != got {
t.Fatalf("expected CompareAndSetPassword error %v got %v", tt.wants.err, err)
}
return
}
})
}
}

View File

@ -18,16 +18,9 @@ type OnboardingFields struct {
IsOnboarding bool IsOnboarding bool
} }
// OnBoardingNBasicAuthService includes onboarding service
// and basic auth service.
type OnBoardingNBasicAuthService interface {
platform.OnboardingService
platform.BasicAuthService
}
// Generate testing // Generate testing
func Generate( func Generate(
init func(OnboardingFields, *testing.T) (OnBoardingNBasicAuthService, func()), init func(OnboardingFields, *testing.T) (platform.OnboardingService, func()),
t *testing.T, t *testing.T,
) { ) {
type args struct { type args struct {

View File

@ -10,6 +10,7 @@ type User struct {
// UserService represents a service for managing user data. // UserService represents a service for managing user data.
type UserService interface { type UserService interface {
// Returns a single user by ID. // Returns a single user by ID.
FindUserByID(ctx context.Context, id ID) (*User, error) FindUserByID(ctx context.Context, id ID) (*User, error)