mirror of https://github.com/milvus-io/milvus.git
Fix gcp oauth token not cached (#20380)
Signed-off-by: shaoyue.chen <shaoyue.chen@zilliz.com> Signed-off-by: shaoyue.chen <shaoyue.chen@zilliz.com>pull/20400/head
parent
ad0cce8f70
commit
09ea38615e
|
@ -13,8 +13,9 @@ import (
|
|||
|
||||
// WrapHTTPTransport wraps http.Transport, add an auth header to support GCP native auth
|
||||
type WrapHTTPTransport struct {
|
||||
tokenSrc oauth2.TokenSource
|
||||
backend transport
|
||||
tokenSrc oauth2.TokenSource
|
||||
backend transport
|
||||
currentToken *oauth2.Token
|
||||
}
|
||||
|
||||
// transport abstracts http.Transport to simplify test
|
||||
|
@ -36,13 +37,18 @@ func NewWrapHTTPTransport(secure bool) (*WrapHTTPTransport, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper
|
||||
// RoundTrip wraps original http.RoundTripper by Adding a Bearer token acquired from tokenSrc
|
||||
func (t *WrapHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
token, err := t.tokenSrc.Token()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to acquire token")
|
||||
// here Valid() means the token won't be expired in 10 sec
|
||||
// so the http client timeout shouldn't be longer, or we need to change the default `expiryDelta` time
|
||||
if !t.currentToken.Valid() {
|
||||
var err error
|
||||
t.currentToken, err = t.tokenSrc.Token()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to acquire token")
|
||||
}
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
req.Header.Set("Authorization", "Bearer "+t.currentToken.AccessToken)
|
||||
return t.backend.RoundTrip(req)
|
||||
}
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) {
|
|||
ts.backend = &mockTransport{}
|
||||
ts.tokenSrc = &mockTokenSource{token: "mocktoken"}
|
||||
|
||||
t.Run("ok", func(t *testing.T) {
|
||||
t.Run("valid token ok", func(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
_, err = ts.RoundTrip(req)
|
||||
|
@ -73,7 +73,8 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) {
|
|||
assert.Equal(t, "Bearer mocktoken", req.Header.Get("Authorization"))
|
||||
})
|
||||
|
||||
t.Run("get token failed", func(t *testing.T) {
|
||||
t.Run("invalid token, refresh failed", func(t *testing.T) {
|
||||
ts.currentToken = nil
|
||||
ts.tokenSrc = &mockTokenSource{err: errors.New("mock error")}
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
@ -81,7 +82,17 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("call failed", func(t *testing.T) {
|
||||
t.Run("invalid token, refresh ok", func(t *testing.T) {
|
||||
ts.currentToken = nil
|
||||
ts.tokenSrc = &mockTokenSource{err: nil}
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
_, err = ts.RoundTrip(req)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
ts.currentToken = &oauth2.Token{}
|
||||
t.Run("valid token, call failed", func(t *testing.T) {
|
||||
ts.backend = &mockTransport{err: errors.New("mock error")}
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
|
Loading…
Reference in New Issue