Make GCS OAuth token thread-safe (#22714)

Signed-off-by: huanghaoyuan <haoyuan.huang@zilliz.com>
pull/22779/head
huanghaoyuanhhy 2023-03-14 18:09:54 +08:00 committed by GitHub
parent 732986aa04
commit 024beddfe6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 9 deletions

2
go.mod
View File

@ -43,7 +43,7 @@ require (
go.etcd.io/etcd/api/v3 v3.5.5
go.etcd.io/etcd/client/v3 v3.5.5
go.etcd.io/etcd/server/v3 v3.5.5
go.uber.org/atomic v1.7.0
go.uber.org/atomic v1.10.0
go.uber.org/automaxprocs v1.4.0
go.uber.org/zap v1.17.0
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4

2
go.sum
View File

@ -884,6 +884,8 @@ go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0=
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=

View File

@ -7,6 +7,7 @@ import (
"github.com/cockroachdb/errors"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"go.uber.org/atomic"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
@ -15,7 +16,7 @@ import (
type WrapHTTPTransport struct {
tokenSrc oauth2.TokenSource
backend transport
currentToken *oauth2.Token
currentToken atomic.Pointer[oauth2.Token]
}
// transport abstracts http.Transport to simplify test
@ -41,14 +42,18 @@ func NewWrapHTTPTransport(secure bool) (*WrapHTTPTransport, error) {
func (t *WrapHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// here Valid() means the token won't be expired in 10 sec
// so the http client timeout shouldn't be longer, or we need to change the default `expiryDelta` time
if !t.currentToken.Valid() {
var err error
t.currentToken, err = t.tokenSrc.Token()
currentToken := t.currentToken.Load()
if currentToken.Valid() {
req.Header.Set("Authorization", "Bearer "+currentToken.AccessToken)
} else {
newToken, err := t.tokenSrc.Token()
if err != nil {
return nil, errors.Wrap(err, "failed to acquire token")
}
t.currentToken.Store(newToken)
req.Header.Set("Authorization", "Bearer "+newToken.AccessToken)
}
req.Header.Set("Authorization", "Bearer "+t.currentToken.AccessToken)
return t.backend.RoundTrip(req)
}

View File

@ -75,7 +75,7 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) {
})
t.Run("invalid token, refresh failed", func(t *testing.T) {
ts.currentToken = nil
ts.currentToken.Store(nil)
ts.tokenSrc = &mockTokenSource{err: errors.New("mock error")}
req, err := http.NewRequest("GET", "http://example.com", nil)
assert.NoError(t, err)
@ -84,7 +84,7 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) {
})
t.Run("invalid token, refresh ok", func(t *testing.T) {
ts.currentToken = nil
ts.currentToken.Store(nil)
ts.tokenSrc = &mockTokenSource{err: nil}
req, err := http.NewRequest("GET", "http://example.com", nil)
assert.NoError(t, err)
@ -92,7 +92,7 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) {
assert.NoError(t, err)
})
ts.currentToken = &oauth2.Token{}
ts.currentToken.Store(&oauth2.Token{})
t.Run("valid token, call failed", func(t *testing.T) {
ts.backend = &mockTransport{err: errors.New("mock error")}
req, err := http.NewRequest("GET", "http://example.com", nil)