mirror of https://github.com/milvus-io/milvus.git
Fix gcp oauth token not cached (#20380)
Signed-off-by: shaoyue.chen <shaoyue.chen@zilliz.com> Signed-off-by: shaoyue.chen <shaoyue.chen@zilliz.com>pull/20400/head
parent
ad0cce8f70
commit
09ea38615e
internal/storage/gcp
|
@ -15,6 +15,7 @@ import (
|
||||||
type WrapHTTPTransport struct {
|
type WrapHTTPTransport struct {
|
||||||
tokenSrc oauth2.TokenSource
|
tokenSrc oauth2.TokenSource
|
||||||
backend transport
|
backend transport
|
||||||
|
currentToken *oauth2.Token
|
||||||
}
|
}
|
||||||
|
|
||||||
// transport abstracts http.Transport to simplify test
|
// transport abstracts http.Transport to simplify test
|
||||||
|
@ -36,13 +37,18 @@ func NewWrapHTTPTransport(secure bool) (*WrapHTTPTransport, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip implements http.RoundTripper
|
// RoundTrip wraps original http.RoundTripper by Adding a Bearer token acquired from tokenSrc
|
||||||
func (t *WrapHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (t *WrapHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
token, err := t.tokenSrc.Token()
|
// here Valid() means the token won't be expired in 10 sec
|
||||||
|
// so the http client timeout shouldn't be longer, or we need to change the default `expiryDelta` time
|
||||||
|
if !t.currentToken.Valid() {
|
||||||
|
var err error
|
||||||
|
t.currentToken, err = t.tokenSrc.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to acquire token")
|
return nil, errors.Wrap(err, "failed to acquire token")
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+t.currentToken.AccessToken)
|
||||||
return t.backend.RoundTrip(req)
|
return t.backend.RoundTrip(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -65,7 +65,7 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) {
|
||||||
ts.backend = &mockTransport{}
|
ts.backend = &mockTransport{}
|
||||||
ts.tokenSrc = &mockTokenSource{token: "mocktoken"}
|
ts.tokenSrc = &mockTokenSource{token: "mocktoken"}
|
||||||
|
|
||||||
t.Run("ok", func(t *testing.T) {
|
t.Run("valid token ok", func(t *testing.T) {
|
||||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
_, err = ts.RoundTrip(req)
|
_, err = ts.RoundTrip(req)
|
||||||
|
@ -73,7 +73,8 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) {
|
||||||
assert.Equal(t, "Bearer mocktoken", req.Header.Get("Authorization"))
|
assert.Equal(t, "Bearer mocktoken", req.Header.Get("Authorization"))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("get token failed", func(t *testing.T) {
|
t.Run("invalid token, refresh failed", func(t *testing.T) {
|
||||||
|
ts.currentToken = nil
|
||||||
ts.tokenSrc = &mockTokenSource{err: errors.New("mock error")}
|
ts.tokenSrc = &mockTokenSource{err: errors.New("mock error")}
|
||||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -81,7 +82,17 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("call failed", func(t *testing.T) {
|
t.Run("invalid token, refresh ok", func(t *testing.T) {
|
||||||
|
ts.currentToken = nil
|
||||||
|
ts.tokenSrc = &mockTokenSource{err: nil}
|
||||||
|
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, err = ts.RoundTrip(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
ts.currentToken = &oauth2.Token{}
|
||||||
|
t.Run("valid token, call failed", func(t *testing.T) {
|
||||||
ts.backend = &mockTransport{err: errors.New("mock error")}
|
ts.backend = &mockTransport{err: errors.New("mock error")}
|
||||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
Loading…
Reference in New Issue