From 09ea38615eac4f60b9589e233583921aede6fca4 Mon Sep 17 00:00:00 2001 From: shaoyue Date: Tue, 8 Nov 2022 12:43:02 +0800 Subject: [PATCH] Fix gcp oauth token not cached (#20380) Signed-off-by: shaoyue.chen Signed-off-by: shaoyue.chen --- internal/storage/gcp/gcp.go | 20 +++++++++++++------- internal/storage/gcp/gcp_test.go | 17 ++++++++++++++--- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/internal/storage/gcp/gcp.go b/internal/storage/gcp/gcp.go index e849adcef8..e4be612a37 100644 --- a/internal/storage/gcp/gcp.go +++ b/internal/storage/gcp/gcp.go @@ -13,8 +13,9 @@ import ( // WrapHTTPTransport wraps http.Transport, add an auth header to support GCP native auth type WrapHTTPTransport struct { - tokenSrc oauth2.TokenSource - backend transport + tokenSrc oauth2.TokenSource + backend transport + currentToken *oauth2.Token } // transport abstracts http.Transport to simplify test @@ -36,13 +37,18 @@ func NewWrapHTTPTransport(secure bool) (*WrapHTTPTransport, error) { }, nil } -// RoundTrip implements http.RoundTripper +// RoundTrip wraps original http.RoundTripper by Adding a Bearer token acquired from tokenSrc func (t *WrapHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { - token, err := t.tokenSrc.Token() - if err != nil { - return nil, errors.Wrap(err, "failed to acquire token") + // here Valid() means the token won't be expired in 10 sec + // so the http client timeout shouldn't be longer, or we need to change the default `expiryDelta` time + if !t.currentToken.Valid() { + var err error + t.currentToken, err = t.tokenSrc.Token() + if err != nil { + return nil, errors.Wrap(err, "failed to acquire token") + } } - req.Header.Set("Authorization", "Bearer "+token.AccessToken) + req.Header.Set("Authorization", "Bearer "+t.currentToken.AccessToken) return t.backend.RoundTrip(req) } diff --git a/internal/storage/gcp/gcp_test.go b/internal/storage/gcp/gcp_test.go index e22ab72d4f..6433b5e909 100644 --- a/internal/storage/gcp/gcp_test.go +++ b/internal/storage/gcp/gcp_test.go @@ -65,7 +65,7 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) { ts.backend = &mockTransport{} ts.tokenSrc = &mockTokenSource{token: "mocktoken"} - t.Run("ok", func(t *testing.T) { + t.Run("valid token ok", func(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com", nil) assert.NoError(t, err) _, err = ts.RoundTrip(req) @@ -73,7 +73,8 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) { assert.Equal(t, "Bearer mocktoken", req.Header.Get("Authorization")) }) - t.Run("get token failed", func(t *testing.T) { + t.Run("invalid token, refresh failed", func(t *testing.T) { + ts.currentToken = nil ts.tokenSrc = &mockTokenSource{err: errors.New("mock error")} req, err := http.NewRequest("GET", "http://example.com", nil) assert.NoError(t, err) @@ -81,7 +82,17 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) { assert.Error(t, err) }) - t.Run("call failed", func(t *testing.T) { + t.Run("invalid token, refresh ok", func(t *testing.T) { + ts.currentToken = nil + ts.tokenSrc = &mockTokenSource{err: nil} + req, err := http.NewRequest("GET", "http://example.com", nil) + assert.NoError(t, err) + _, err = ts.RoundTrip(req) + assert.NoError(t, err) + }) + + ts.currentToken = &oauth2.Token{} + t.Run("valid token, call failed", func(t *testing.T) { ts.backend = &mockTransport{err: errors.New("mock error")} req, err := http.NewRequest("GET", "http://example.com", nil) assert.NoError(t, err)