Fix gcp oauth token not cached (#20380)

Signed-off-by: shaoyue.chen <shaoyue.chen@zilliz.com>

Signed-off-by: shaoyue.chen <shaoyue.chen@zilliz.com>
pull/20400/head
shaoyue 2022-11-08 12:43:02 +08:00 committed by GitHub
parent ad0cce8f70
commit 09ea38615e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 10 deletions

View File

@ -13,8 +13,9 @@ import (
// WrapHTTPTransport wraps http.Transport, add an auth header to support GCP native auth
type WrapHTTPTransport struct {
tokenSrc oauth2.TokenSource
backend transport
tokenSrc oauth2.TokenSource
backend transport
currentToken *oauth2.Token
}
// transport abstracts http.Transport to simplify test
@ -36,13 +37,18 @@ func NewWrapHTTPTransport(secure bool) (*WrapHTTPTransport, error) {
}, nil
}
// RoundTrip implements http.RoundTripper
// RoundTrip wraps original http.RoundTripper by Adding a Bearer token acquired from tokenSrc
func (t *WrapHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
token, err := t.tokenSrc.Token()
if err != nil {
return nil, errors.Wrap(err, "failed to acquire token")
// here Valid() means the token won't be expired in 10 sec
// so the http client timeout shouldn't be longer, or we need to change the default `expiryDelta` time
if !t.currentToken.Valid() {
var err error
t.currentToken, err = t.tokenSrc.Token()
if err != nil {
return nil, errors.Wrap(err, "failed to acquire token")
}
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
req.Header.Set("Authorization", "Bearer "+t.currentToken.AccessToken)
return t.backend.RoundTrip(req)
}

View File

@ -65,7 +65,7 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) {
ts.backend = &mockTransport{}
ts.tokenSrc = &mockTokenSource{token: "mocktoken"}
t.Run("ok", func(t *testing.T) {
t.Run("valid token ok", func(t *testing.T) {
req, err := http.NewRequest("GET", "http://example.com", nil)
assert.NoError(t, err)
_, err = ts.RoundTrip(req)
@ -73,7 +73,8 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) {
assert.Equal(t, "Bearer mocktoken", req.Header.Get("Authorization"))
})
t.Run("get token failed", func(t *testing.T) {
t.Run("invalid token, refresh failed", func(t *testing.T) {
ts.currentToken = nil
ts.tokenSrc = &mockTokenSource{err: errors.New("mock error")}
req, err := http.NewRequest("GET", "http://example.com", nil)
assert.NoError(t, err)
@ -81,7 +82,17 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) {
assert.Error(t, err)
})
t.Run("call failed", func(t *testing.T) {
t.Run("invalid token, refresh ok", func(t *testing.T) {
ts.currentToken = nil
ts.tokenSrc = &mockTokenSource{err: nil}
req, err := http.NewRequest("GET", "http://example.com", nil)
assert.NoError(t, err)
_, err = ts.RoundTrip(req)
assert.NoError(t, err)
})
ts.currentToken = &oauth2.Token{}
t.Run("valid token, call failed", func(t *testing.T) {
ts.backend = &mockTransport{err: errors.New("mock error")}
req, err := http.NewRequest("GET", "http://example.com", nil)
assert.NoError(t, err)