diff --git a/oauth2/code_exchange_test.go b/oauth2/code_exchange_test.go index a2ac9f1fd..230786bda 100644 --- a/oauth2/code_exchange_test.go +++ b/oauth2/code_exchange_test.go @@ -1,7 +1,6 @@ package oauth2 import ( - "bytes" "context" "encoding/base64" "encoding/json" @@ -12,6 +11,7 @@ import ( "time" clog "github.com/influxdata/chronograf/log" + "github.com/stretchr/testify/require" ) func Test_CodeExchangeCSRF_AuthCodeURL(t *testing.T) { @@ -37,13 +37,9 @@ func Test_CodeExchangeCSRF_AuthCodeURL(t *testing.T) { // create AuthCodeURL with code exchange without PKCE codeExchange := NewCodeExchange(false, "") authCodeURLString, err := codeExchange.AuthCodeURL(context.Background(), authMux) - if err != nil { - t.Fatal("Error generating AuthCodeURL: ", err) - } + require.Nil(t, err, "Unable to generate AuthCodeURL") authCodeURL, err := url.ParseRequestURI(authCodeURLString) - if err != nil { - t.Fatal("Error in AuthCodeURL format: ", err) - } + require.Nil(t, err, "AuthCodeURL format invalid") expectedParams := map[string]string{ "access_type": "online", @@ -104,24 +100,14 @@ func Test_CodeExchangeCSRF_ExchangeCodeForToken(t *testing.T) { // create AuthCodeURL using CodeExchange with PKCE codeExchange := simpleTokenExchange authCodeURLString, err := codeExchange.AuthCodeURL(context.Background(), authMux) - if err != nil { - t.Fatal("Unable to generate AuthCodeURL ", err) - } + require.Nil(t, err, "Unable to generate AuthCodeURL") authCodeURL, err := url.ParseRequestURI(authCodeURLString) - if err != nil { - t.Fatal("Error in AuthCodeURL format: ", err) - } + require.Nil(t, err, "AuthCodeURL format invalid") state := authCodeURL.Query().Get("state") token, err := codeExchange.ExchangeCodeForToken(context.Background(), state, "any code", authMux) - if err != nil { - t.Fatal("ExchangeCodeForToken ends with error: ", err) - } - if token == nil { - t.Fatal("No token received!") - } - if token.AccessToken != testToken { - t.Errorf("Token expected: %s got: %s", testToken, token.AccessToken) - } + require.Nil(t, err, "ExchangeCodeForToken ends with error") + require.NotNil(t, token, "No token received") + require.Equal(t, testToken, token.AccessToken) expectedParams := []string{"code"} for _, key := range expectedParams { foundVal := exchangeUrlValues.Get(key) @@ -155,13 +141,9 @@ func Test_CodeExchangePKCE_AuthCodeURL(t *testing.T) { // create AuthCodeURL using CodeExchange with PKCE codeExchange := NewCodeExchange(true, "secret") authCodeURLString, err := codeExchange.AuthCodeURL(context.Background(), authMux) - if err != nil { - t.Fatal("Error generating AuthCodeURL: ", err) - } + require.Nil(t, err, "Unable to generate AuthCodeURL") authCodeURL, err := url.ParseRequestURI(authCodeURLString) - if err != nil { - t.Fatal("Error in AuthCodeURL format: ", err) - } + require.Nil(t, err, "Invalid AuthCodeURL format") expectedParams := map[string]string{ "access_type": "online", @@ -194,16 +176,10 @@ func Test_CodeExchangePKCE_EncryptDecrypt(t *testing.T) { codeExchange := &CodeExchangePKCE{Secret: "hardtoguess"} plain := []byte("this is a test") encrypted, err := codeExchange.Encrypt([]byte(plain)) - if err != nil { - t.Fatal("Unable to encrypt plain text ", err) - } + require.Nil(t, err, "Unable to encrypt plain text ") decrypted, err := codeExchange.Decrypt(encrypted) - if err != nil { - t.Fatal("Unable to encrypt plain text ", err) - } - if !bytes.Equal(plain, decrypted) { - t.Errorf("Decrypted data are different; expected: %v , got: %v", plain, decrypted) - } + require.Nil(t, err, "Unable to decrypt plain text") + require.Equal(t, plain, decrypted) } func Test_CodeExchangePKCE_ExchangeCodeForToken(t *testing.T) { @@ -242,32 +218,18 @@ func Test_CodeExchangePKCE_ExchangeCodeForToken(t *testing.T) { // create AuthCodeURL using CodeExchange with PKCE codeExchange := CodeExchangePKCE{Secret: secret} authCodeURLString, err := codeExchange.AuthCodeURL(context.Background(), authMux) - if err != nil { - t.Fatal("Unable to generate AuthCodeURL ", err) - } + require.Nil(t, err, "Unable to generate AuthCodeURL") authCodeURL, err := url.ParseRequestURI(authCodeURLString) - if err != nil { - t.Fatal("Error in AuthCodeURL format: ", err) - } + require.Nil(t, err, "Invalid AuthCodeURL format") state := authCodeURL.Query().Get("state") token, err := codeExchange.ExchangeCodeForToken(context.Background(), state, "any code", authMux) - if err != nil { - t.Fatal("ExchangeCodeForToken ends with error: ", err) - } - if token == nil { - t.Fatal("No token received!") - } - if token.AccessToken != testToken { - t.Errorf("Token expected: %s got: %s", testToken, token.AccessToken) - } + require.Nil(t, err) + require.NotNil(t, token) + require.Equal(t, testToken, token.AccessToken) stateData, err := jwt.ValidPrincipal(context.Background(), Token(state), TenMinutes) - if err != nil { - t.Fatalf("invalid OAuth state received: %v", err.Error()) - } + require.Nil(t, err, "invalid OAuth state received") codeVerifierBytes, err := codeExchange.Decrypt(stateData.Subject) - if err != nil { - t.Fatalf("invalid OAuth state received: %v", err.Error()) - } + require.Nil(t, err, "unable to decrypt code verifier") codeVerifier := base64.RawURLEncoding.EncodeToString(codeVerifierBytes) expectedParams := map[string]string{ "code": "",