tweak a couple of dependency versions

Signed-off-by: Steve Kriss <krisss@vmware.com>
pull/1231/head
Steve Kriss 2019-02-22 09:50:08 -07:00
parent 58e471bda0
commit f8548e1ca1
37 changed files with 1383 additions and 599 deletions

17
Gopkg.lock generated
View File

@ -29,7 +29,7 @@
version = "v11.3.0-beta"
[[projects]]
digest = "1:8fc5113397520cfaddc7d0db0d2195528592866dfe1e17e6c7b0b7243533f01a"
digest = "1:b825d8578481c8877ff3b9a3654d77a48577cc33e65f33c3678d7e3f134bf73d"
name = "github.com/Azure/go-autorest"
packages = [
"autorest",
@ -38,9 +38,11 @@
"autorest/date",
"autorest/to",
"autorest/validation",
"version",
]
pruneopts = "NUT"
revision = "1ff28809256a84bb6966640ff3d0371af82ccba4"
revision = "bca49d5b51a50dc5bb17bbf6204c711c6dbded06"
version = "v10.14.0"
[[projects]]
digest = "1:f41188abdb95b92995643a927f5bdd208389822a8e1aba00d85633ae51b85c85"
@ -281,11 +283,12 @@
revision = "0b12d6b5"
[[projects]]
digest = "1:0243cffa4a3410f161ee613dfdd903a636d07e838a42d341da95d81f42cd1d41"
digest = "1:8e36686e8b139f8fe240c1d5cf3a145bc675c22ff8e707857cdd3ae17b00d728"
name = "github.com/json-iterator/go"
packages = ["."]
pruneopts = "NUT"
revision = "f2b4162afba35581b6d4a50d3b8f34e33c144682"
revision = "1624edc4454b8682399def8740d46db5e4362ba4"
version = "v1.1.5"
[[projects]]
digest = "1:5985ef4caf91ece5d54817c11ea25f182697534f8ae6521eadcd628c142ac4b6"
@ -312,12 +315,12 @@
version = "1.0.3"
[[projects]]
digest = "1:314a5881fab303a80d6d2e35a77000f2224bb50f09ef63a9aa4c1f9eaef985d8"
digest = "1:c6aca19413b13dc59c220ad7430329e2ec454cc310bc6d8de2c7e2b93c18a0f6"
name = "github.com/modern-go/reflect2"
packages = ["."]
pruneopts = "NUT"
revision = "1df9eeb2bb81f327b96228865c5687bc2194af3f"
version = "1.0.0"
revision = "4b7aa43c6742a2c18fdef89dd197aaae7dac7ccd"
version = "1.0.1"
[[projects]]
branch = "master"

View File

@ -49,17 +49,10 @@
name = "k8s.io/apiextensions-apiserver"
version = "kubernetes-1.12.0"
# vendor/k8s.io/apimachinery/pkg/runtime/serializer/json/json.go:104:16:
# unknown field 'CaseSensitive' in struct literal of type jsoniter.Config
# k8s.io/client-go v9.0 uses f2b4162afba35581b6d4a50d3b8f34e33c144682 (released in v1.1.4)
[[override]]
name = "github.com/json-iterator/go"
revision = "f2b4162afba35581b6d4a50d3b8f34e33c144682"
# vendor/k8s.io/client-go/plugin/pkg/client/auth/azure/azure.go:300:25:
# cannot call non-function spt.Token (type adal.Token)
[[override]]
name = "github.com/Azure/go-autorest"
revision = "1ff28809256a84bb6966640ff3d0371af82ccba4"
version = "~1.1.4"
#
# Cloud provider packages
@ -72,6 +65,11 @@
name = "github.com/Azure/azure-sdk-for-go"
version = "~11.3.0-beta"
# k8s.io/client-go v9.0 uses bca49d5b51a50dc5bb17bbf6204c711c6dbded06 (v10.14.0)
[[constraint]]
name = "github.com/Azure/go-autorest"
version = "~10.14.0"
[[constraint]]
name = "cloud.google.com/go"
version = "0.11.0"

View File

@ -26,10 +26,10 @@ const (
// OAuthConfig represents the endpoints needed
// in OAuth operations
type OAuthConfig struct {
AuthorityEndpoint url.URL
AuthorizeEndpoint url.URL
TokenEndpoint url.URL
DeviceCodeEndpoint url.URL
AuthorityEndpoint url.URL `json:"authorityEndpoint"`
AuthorizeEndpoint url.URL `json:"authorizeEndpoint"`
TokenEndpoint url.URL `json:"tokenEndpoint"`
DeviceCodeEndpoint url.URL `json:"deviceCodeEndpoint"`
}
// IsZero returns true if the OAuthConfig object is zero-initialized.

View File

@ -15,14 +15,18 @@ package adal
// limitations under the License.
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"math"
"net"
"net/http"
"net/url"
"strconv"
@ -31,6 +35,7 @@ import (
"time"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/version"
"github.com/dgrijalva/jwt-go"
)
@ -57,6 +62,9 @@ const (
// msiEndpoint is the well known endpoint for getting MSI authentications tokens
msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
// the default number of attempts to refresh an MSI authentication token
defaultMaxMSIRefreshAttempts = 5
)
// OAuthTokenProvider is an interface which should be implemented by an access token retriever
@ -77,6 +85,13 @@ type Refresher interface {
EnsureFresh() error
}
// RefresherWithContext is an interface for token refresh functionality
type RefresherWithContext interface {
RefreshWithContext(ctx context.Context) error
RefreshExchangeWithContext(ctx context.Context, resource string) error
EnsureFreshWithContext(ctx context.Context) error
}
// TokenRefreshCallback is the type representing callbacks that will be called after
// a successful token refresh
type TokenRefreshCallback func(Token) error
@ -127,6 +142,12 @@ func (t *Token) OAuthToken() string {
return t.AccessToken
}
// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
// that is submitted when acquiring an oAuth token.
type ServicePrincipalSecret interface {
SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
}
// ServicePrincipalNoSecret represents a secret type that contains no secret
// meaning it is not valid for fetching a fresh token. This is used by Manual
type ServicePrincipalNoSecret struct {
@ -138,15 +159,19 @@ func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePr
return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
}
// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
// that is submitted when acquiring an oAuth token.
type ServicePrincipalSecret interface {
SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
// MarshalJSON implements the json.Marshaler interface.
func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
type tokenType struct {
Type string `json:"type"`
}
return json.Marshal(tokenType{
Type: "ServicePrincipalNoSecret",
})
}
// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
type ServicePrincipalTokenSecret struct {
ClientSecret string
ClientSecret string `json:"value"`
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
@ -156,49 +181,24 @@ func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *Ser
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
type tokenType struct {
Type string `json:"type"`
Value string `json:"value"`
}
return json.Marshal(tokenType{
Type: "ServicePrincipalTokenSecret",
Value: tokenSecret.ClientSecret,
})
}
// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
type ServicePrincipalCertificateSecret struct {
Certificate *x509.Certificate
PrivateKey *rsa.PrivateKey
}
// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
type ServicePrincipalMSISecret struct {
}
// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
type ServicePrincipalUsernamePasswordSecret struct {
Username string
Password string
}
// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
type ServicePrincipalAuthorizationCodeSecret struct {
ClientSecret string
AuthorizationCode string
RedirectURI string
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("code", secret.AuthorizationCode)
v.Set("client_secret", secret.ClientSecret)
v.Set("redirect_uri", secret.RedirectURI)
return nil
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("username", secret.Username)
v.Set("password", secret.Password)
return nil
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
return nil
}
// SignJwt returns the JWT signed with the certificate's private key.
func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
hasher := sha1.New()
@ -219,9 +219,9 @@ func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalTo
token := jwt.New(jwt.SigningMethodRS256)
token.Header["x5t"] = thumbprint
token.Claims = jwt.MapClaims{
"aud": spt.oauthConfig.TokenEndpoint.String(),
"iss": spt.clientID,
"sub": spt.clientID,
"aud": spt.inner.OauthConfig.TokenEndpoint.String(),
"iss": spt.inner.ClientID,
"sub": spt.inner.ClientID,
"jti": base64.URLEncoding.EncodeToString(jti),
"nbf": time.Now().Unix(),
"exp": time.Now().Add(time.Hour * 24).Unix(),
@ -244,19 +244,151 @@ func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *Se
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
}
// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
type ServicePrincipalMSISecret struct {
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
}
// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
type ServicePrincipalUsernamePasswordSecret struct {
Username string `json:"username"`
Password string `json:"password"`
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("username", secret.Username)
v.Set("password", secret.Password)
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
type tokenType struct {
Type string `json:"type"`
Username string `json:"username"`
Password string `json:"password"`
}
return json.Marshal(tokenType{
Type: "ServicePrincipalUsernamePasswordSecret",
Username: secret.Username,
Password: secret.Password,
})
}
// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
type ServicePrincipalAuthorizationCodeSecret struct {
ClientSecret string `json:"value"`
AuthorizationCode string `json:"authCode"`
RedirectURI string `json:"redirect"`
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("code", secret.AuthorizationCode)
v.Set("client_secret", secret.ClientSecret)
v.Set("redirect_uri", secret.RedirectURI)
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
type tokenType struct {
Type string `json:"type"`
Value string `json:"value"`
AuthCode string `json:"authCode"`
Redirect string `json:"redirect"`
}
return json.Marshal(tokenType{
Type: "ServicePrincipalAuthorizationCodeSecret",
Value: secret.ClientSecret,
AuthCode: secret.AuthorizationCode,
Redirect: secret.RedirectURI,
})
}
// ServicePrincipalToken encapsulates a Token created for a Service Principal.
type ServicePrincipalToken struct {
token Token
secret ServicePrincipalSecret
oauthConfig OAuthConfig
clientID string
resource string
autoRefresh bool
refreshLock *sync.RWMutex
refreshWithin time.Duration
sender Sender
inner servicePrincipalToken
refreshLock *sync.RWMutex
sender Sender
refreshCallbacks []TokenRefreshCallback
// MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
MaxMSIRefreshAttempts int
}
// MarshalTokenJSON returns the marshalled inner token.
func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
return json.Marshal(spt.inner.Token)
}
// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks.
func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
spt.refreshCallbacks = callbacks
}
// MarshalJSON implements the json.Marshaler interface.
func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
return json.Marshal(spt.inner)
}
// UnmarshalJSON implements the json.Unmarshaler interface.
func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
// need to determine the token type
raw := map[string]interface{}{}
err := json.Unmarshal(data, &raw)
if err != nil {
return err
}
secret := raw["secret"].(map[string]interface{})
switch secret["type"] {
case "ServicePrincipalNoSecret":
spt.inner.Secret = &ServicePrincipalNoSecret{}
case "ServicePrincipalTokenSecret":
spt.inner.Secret = &ServicePrincipalTokenSecret{}
case "ServicePrincipalCertificateSecret":
return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
case "ServicePrincipalMSISecret":
return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
case "ServicePrincipalUsernamePasswordSecret":
spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
case "ServicePrincipalAuthorizationCodeSecret":
spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
default:
return fmt.Errorf("unrecognized token type '%s'", secret["type"])
}
err = json.Unmarshal(data, &spt.inner)
if err != nil {
return err
}
spt.refreshLock = &sync.RWMutex{}
spt.sender = &http.Client{}
return nil
}
// internal type used for marshalling/unmarshalling
type servicePrincipalToken struct {
Token Token `json:"token"`
Secret ServicePrincipalSecret `json:"secret"`
OauthConfig OAuthConfig `json:"oauth"`
ClientID string `json:"clientID"`
Resource string `json:"resource"`
AutoRefresh bool `json:"autoRefresh"`
RefreshWithin time.Duration `json:"refreshWithin"`
}
func validateOAuthConfig(oac OAuthConfig) error {
@ -281,13 +413,15 @@ func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, reso
return nil, fmt.Errorf("parameter 'secret' cannot be nil")
}
spt := &ServicePrincipalToken{
oauthConfig: oauthConfig,
secret: secret,
clientID: id,
resource: resource,
autoRefresh: true,
inner: servicePrincipalToken{
OauthConfig: oauthConfig,
Secret: secret,
ClientID: id,
Resource: resource,
AutoRefresh: true,
RefreshWithin: defaultRefresh,
},
refreshLock: &sync.RWMutex{},
refreshWithin: defaultRefresh,
sender: &http.Client{},
refreshCallbacks: callbacks,
}
@ -318,7 +452,39 @@ func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID s
return nil, err
}
spt.token = token
spt.inner.Token = token
return spt, nil
}
// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret
func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if secret == nil {
return nil, fmt.Errorf("parameter 'secret' cannot be nil")
}
if token.IsZero() {
return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
}
spt, err := NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
secret,
callbacks...)
if err != nil {
return nil, err
}
spt.inner.Token = token
return spt, nil
}
@ -486,20 +652,23 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI
msiEndpointURL.RawQuery = v.Encode()
spt := &ServicePrincipalToken{
oauthConfig: OAuthConfig{
TokenEndpoint: *msiEndpointURL,
inner: servicePrincipalToken{
OauthConfig: OAuthConfig{
TokenEndpoint: *msiEndpointURL,
},
Secret: &ServicePrincipalMSISecret{},
Resource: resource,
AutoRefresh: true,
RefreshWithin: defaultRefresh,
},
secret: &ServicePrincipalMSISecret{},
resource: resource,
autoRefresh: true,
refreshLock: &sync.RWMutex{},
refreshWithin: defaultRefresh,
sender: &http.Client{},
refreshCallbacks: callbacks,
refreshLock: &sync.RWMutex{},
sender: &http.Client{},
refreshCallbacks: callbacks,
MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
}
if userAssignedID != nil {
spt.clientID = *userAssignedID
spt.inner.ClientID = *userAssignedID
}
return spt, nil
@ -528,12 +697,18 @@ func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError
// EnsureFresh will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFresh() error {
if spt.autoRefresh && spt.token.WillExpireIn(spt.refreshWithin) {
return spt.EnsureFreshWithContext(context.Background())
}
// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
if spt.inner.AutoRefresh && spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
// take the write lock then check to see if the token was already refreshed
spt.refreshLock.Lock()
defer spt.refreshLock.Unlock()
if spt.token.WillExpireIn(spt.refreshWithin) {
return spt.refreshInternal(spt.resource)
if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
return spt.refreshInternal(ctx, spt.inner.Resource)
}
}
return nil
@ -543,7 +718,7 @@ func (spt *ServicePrincipalToken) EnsureFresh() error {
func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
if spt.refreshCallbacks != nil {
for _, callback := range spt.refreshCallbacks {
err := callback(spt.token)
err := callback(spt.inner.Token)
if err != nil {
return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
}
@ -555,21 +730,33 @@ func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
// Refresh obtains a fresh token for the Service Principal.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) Refresh() error {
return spt.RefreshWithContext(context.Background())
}
// RefreshWithContext obtains a fresh token for the Service Principal.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
spt.refreshLock.Lock()
defer spt.refreshLock.Unlock()
return spt.refreshInternal(spt.resource)
return spt.refreshInternal(ctx, spt.inner.Resource)
}
// RefreshExchange refreshes the token, but for a different resource.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
return spt.RefreshExchangeWithContext(context.Background(), resource)
}
// RefreshExchangeWithContext refreshes the token, but for a different resource.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
spt.refreshLock.Lock()
defer spt.refreshLock.Unlock()
return spt.refreshInternal(resource)
return spt.refreshInternal(ctx, resource)
}
func (spt *ServicePrincipalToken) getGrantType() string {
switch spt.secret.(type) {
switch spt.inner.Secret.(type) {
case *ServicePrincipalUsernamePasswordSecret:
return OAuthGrantTypeUserPass
case *ServicePrincipalAuthorizationCodeSecret:
@ -587,23 +774,32 @@ func isIMDS(u url.URL) bool {
return u.Host == imds.Host && u.Path == imds.Path
}
func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
req, err := http.NewRequest(http.MethodPost, spt.oauthConfig.TokenEndpoint.String(), nil)
func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
if err != nil {
return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
}
if !isIMDS(spt.oauthConfig.TokenEndpoint) {
req.Header.Add("User-Agent", version.UserAgent())
req = req.WithContext(ctx)
if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
v := url.Values{}
v.Set("client_id", spt.clientID)
v.Set("client_id", spt.inner.ClientID)
v.Set("resource", resource)
if spt.token.RefreshToken != "" {
if spt.inner.Token.RefreshToken != "" {
v.Set("grant_type", OAuthGrantTypeRefreshToken)
v.Set("refresh_token", spt.token.RefreshToken)
v.Set("refresh_token", spt.inner.Token.RefreshToken)
// web apps must specify client_secret when refreshing tokens
// see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
if err != nil {
return err
}
}
} else {
v.Set("grant_type", spt.getGrantType())
err := spt.secret.SetAuthenticationValues(spt, &v)
err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
if err != nil {
return err
}
@ -616,19 +812,19 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
req.Body = body
}
if _, ok := spt.secret.(*ServicePrincipalMSISecret); ok {
if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
req.Method = http.MethodGet
req.Header.Set(metadataHeader, "true")
}
var resp *http.Response
if isIMDS(spt.oauthConfig.TokenEndpoint) {
resp, err = retry(spt.sender, req)
if isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
} else {
resp, err = spt.sender.Do(req)
}
if err != nil {
return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil)
}
defer resp.Body.Close()
@ -636,11 +832,15 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
if resp.StatusCode != http.StatusOK {
if err != nil {
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body", resp.StatusCode), resp)
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp)
}
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
}
// for the following error cases don't return a TokenRefreshError. the operation succeeded
// but some transient failure happened during deserialization. by returning a generic error
// the retry logic will kick in (we don't retry on TokenRefreshError).
if err != nil {
return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
}
@ -653,12 +853,14 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
}
spt.token = token
spt.inner.Token = token
return spt.InvokeRefreshCallbacks(token)
}
func retry(sender Sender, req *http.Request) (resp *http.Response, err error) {
// retry logic specific to retrieving a token from the IMDS endpoint
func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
// copied from client.go due to circular dependency
retries := []int{
http.StatusRequestTimeout, // 408
http.StatusTooManyRequests, // 429
@ -667,8 +869,10 @@ func retry(sender Sender, req *http.Request) (resp *http.Response, err error) {
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout, // 504
}
// Extra retry status codes requered
retries = append(retries, http.StatusNotFound,
// extra retry status codes specific to IMDS
retries = append(retries,
http.StatusNotFound,
http.StatusGone,
// all remaining 5xx
http.StatusNotImplemented,
http.StatusHTTPVersionNotSupported,
@ -678,34 +882,52 @@ func retry(sender Sender, req *http.Request) (resp *http.Response, err error) {
http.StatusNotExtended,
http.StatusNetworkAuthenticationRequired)
// see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance
const maxDelay time.Duration = 60 * time.Second
attempt := 0
maxAttempts := 5
delay := time.Duration(0)
for attempt < maxAttempts {
resp, err = sender.Do(req)
if err != nil {
// retry on temporary network errors, e.g. transient network failures.
// if we don't receive a response then assume we can't connect to the
// endpoint so we're likely not running on an Azure VM so don't retry.
if (err != nil && !isTemporaryNetworkError(err)) || resp == nil || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) {
return
}
if resp.StatusCode == http.StatusOK {
return
// perform exponential backoff with a cap.
// must increment attempt before calculating delay.
attempt++
// the base value of 2 is the "delta backoff" as specified in the guidance doc
delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second)
if delay > maxDelay {
delay = maxDelay
}
if containsInt(retries, resp.StatusCode) {
delayed := false
if resp.StatusCode == http.StatusTooManyRequests {
delayed = delay(resp, req.Cancel)
}
if !delayed {
time.Sleep(time.Second)
attempt++
}
} else {
select {
case <-time.After(delay):
// intentionally left blank
case <-req.Context().Done():
err = req.Context().Err()
return
}
}
return
}
// returns true if the specified error is a temporary network error or false if it's not.
// if the error doesn't implement the net.Error interface the return value is true.
func isTemporaryNetworkError(err error) bool {
if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) {
return true
}
return false
}
// returns true if slice ints contains the value n
func containsInt(ints []int, n int) bool {
for _, i := range ints {
if i == n {
@ -715,31 +937,15 @@ func containsInt(ints []int, n int) bool {
return false
}
func delay(resp *http.Response, cancel <-chan struct{}) bool {
if resp == nil {
return false
}
retryAfter, _ := strconv.Atoi(resp.Header.Get("Retry-After"))
if resp.StatusCode == http.StatusTooManyRequests && retryAfter > 0 {
select {
case <-time.After(time.Duration(retryAfter) * time.Second):
return true
case <-cancel:
return false
}
}
return false
}
// SetAutoRefresh enables or disables automatic refreshing of stale tokens.
func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
spt.autoRefresh = autoRefresh
spt.inner.AutoRefresh = autoRefresh
}
// SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will
// refresh the token.
func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
spt.refreshWithin = d
spt.inner.RefreshWithin = d
return
}
@ -751,12 +957,12 @@ func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
func (spt *ServicePrincipalToken) OAuthToken() string {
spt.refreshLock.RLock()
defer spt.refreshLock.RUnlock()
return spt.token.OAuthToken()
return spt.inner.Token.OAuthToken()
}
// Token returns a copy of the current token.
func (spt *ServicePrincipalToken) Token() Token {
spt.refreshLock.RLock()
defer spt.refreshLock.RUnlock()
return spt.token
return spt.inner.Token
}

View File

@ -113,17 +113,19 @@ func (ba *BearerAuthorizer) WithAuthorization() PrepareDecorator {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
refresher, ok := ba.tokenProvider.(adal.Refresher)
if ok {
err := refresher.EnsureFresh()
if err != nil {
var resp *http.Response
if tokError, ok := err.(adal.TokenRefreshError); ok {
resp = tokError.Response()
}
return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp,
"Failed to refresh the Token for request to %s", r.URL)
// the ordering is important here, prefer RefresherWithContext if available
if refresher, ok := ba.tokenProvider.(adal.RefresherWithContext); ok {
err = refresher.EnsureFreshWithContext(r.Context())
} else if refresher, ok := ba.tokenProvider.(adal.Refresher); ok {
err = refresher.EnsureFresh()
}
if err != nil {
var resp *http.Response
if tokError, ok := err.(adal.TokenRefreshError); ok {
resp = tokError.Response()
}
return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp,
"Failed to refresh the Token for request to %s", r.URL)
}
return Prepare(r, WithHeader(headerAuthorization, fmt.Sprintf("Bearer %s", ba.tokenProvider.OAuthToken())))
}

File diff suppressed because it is too large Load Diff

View File

@ -44,11 +44,12 @@ const (
// ServiceError encapsulates the error response from an Azure service.
// It adhears to the OData v4 specification for error responses.
type ServiceError struct {
Code string `json:"code"`
Message string `json:"message"`
Target *string `json:"target"`
Details []map[string]interface{} `json:"details"`
InnerError map[string]interface{} `json:"innererror"`
Code string `json:"code"`
Message string `json:"message"`
Target *string `json:"target"`
Details []map[string]interface{} `json:"details"`
InnerError map[string]interface{} `json:"innererror"`
AdditionalInfo []map[string]interface{} `json:"additionalInfo"`
}
func (se ServiceError) Error() string {
@ -74,6 +75,14 @@ func (se ServiceError) Error() string {
result += fmt.Sprintf(" InnerError=%v", string(d))
}
if se.AdditionalInfo != nil {
d, err := json.Marshal(se.AdditionalInfo)
if err != nil {
result += fmt.Sprintf(" AdditionalInfo=%v", se.AdditionalInfo)
}
result += fmt.Sprintf(" AdditionalInfo=%v", string(d))
}
return result
}
@ -86,44 +95,47 @@ func (se *ServiceError) UnmarshalJSON(b []byte) error {
// http://docs.oasis-open.org/odata/odata-json-format/v4.0/os/odata-json-format-v4.0-os.html#_Toc372793091
type serviceError1 struct {
Code string `json:"code"`
Message string `json:"message"`
Target *string `json:"target"`
Details []map[string]interface{} `json:"details"`
InnerError map[string]interface{} `json:"innererror"`
Code string `json:"code"`
Message string `json:"message"`
Target *string `json:"target"`
Details []map[string]interface{} `json:"details"`
InnerError map[string]interface{} `json:"innererror"`
AdditionalInfo []map[string]interface{} `json:"additionalInfo"`
}
type serviceError2 struct {
Code string `json:"code"`
Message string `json:"message"`
Target *string `json:"target"`
Details map[string]interface{} `json:"details"`
InnerError map[string]interface{} `json:"innererror"`
Code string `json:"code"`
Message string `json:"message"`
Target *string `json:"target"`
Details map[string]interface{} `json:"details"`
InnerError map[string]interface{} `json:"innererror"`
AdditionalInfo []map[string]interface{} `json:"additionalInfo"`
}
se1 := serviceError1{}
err := json.Unmarshal(b, &se1)
if err == nil {
se.populate(se1.Code, se1.Message, se1.Target, se1.Details, se1.InnerError)
se.populate(se1.Code, se1.Message, se1.Target, se1.Details, se1.InnerError, se1.AdditionalInfo)
return nil
}
se2 := serviceError2{}
err = json.Unmarshal(b, &se2)
if err == nil {
se.populate(se2.Code, se2.Message, se2.Target, nil, se2.InnerError)
se.populate(se2.Code, se2.Message, se2.Target, nil, se2.InnerError, se2.AdditionalInfo)
se.Details = append(se.Details, se2.Details)
return nil
}
return err
}
func (se *ServiceError) populate(code, message string, target *string, details []map[string]interface{}, inner map[string]interface{}) {
func (se *ServiceError) populate(code, message string, target *string, details []map[string]interface{}, inner map[string]interface{}, additional []map[string]interface{}) {
se.Code = code
se.Message = message
se.Target = target
se.Details = details
se.InnerError = inner
se.AdditionalInfo = additional
}
// RequestError describes an error response returned by Azure service.
@ -279,16 +291,29 @@ func WithErrorUnlessStatusCode(codes ...int) autorest.RespondDecorator {
resp.Body = ioutil.NopCloser(&b)
if decodeErr != nil {
return fmt.Errorf("autorest/azure: error response cannot be parsed: %q error: %v", b.String(), decodeErr)
} else if e.ServiceError == nil {
}
if e.ServiceError == nil {
// Check if error is unwrapped ServiceError
if err := json.Unmarshal(b.Bytes(), &e.ServiceError); err != nil || e.ServiceError.Message == "" {
e.ServiceError = &ServiceError{
Code: "Unknown",
Message: "Unknown service error",
}
if err := json.Unmarshal(b.Bytes(), &e.ServiceError); err != nil {
return err
}
}
if e.ServiceError.Message == "" {
// if we're here it means the returned error wasn't OData v4 compliant.
// try to unmarshal the body as raw JSON in hopes of getting something.
rawBody := map[string]interface{}{}
if err := json.Unmarshal(b.Bytes(), &rawBody); err != nil {
return err
}
e.ServiceError = &ServiceError{
Code: "Unknown",
Message: "Unknown service error",
}
if len(rawBody) > 0 {
e.ServiceError.Details = []map[string]interface{}{rawBody}
}
}
e.Response = resp
e.RequestID = ExtractRequestID(resp)
if e.StatusCode == nil {
e.StatusCode = resp.StatusCode

View File

@ -64,7 +64,7 @@ func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator {
}
}
}
return resp, fmt.Errorf("failed request: %s", err)
return resp, err
})
}
}

View File

@ -22,8 +22,9 @@ import (
"log"
"net/http"
"net/http/cookiejar"
"runtime"
"time"
"github.com/Azure/go-autorest/version"
)
const (
@ -41,15 +42,6 @@ const (
)
var (
// defaultUserAgent builds a string containing the Go version, system archityecture and OS,
// and the go-autorest version.
defaultUserAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s",
runtime.Version(),
runtime.GOARCH,
runtime.GOOS,
Version(),
)
// StatusCodesForRetry are a defined group of status code for which the client will retry
StatusCodesForRetry = []int{
http.StatusRequestTimeout, // 408
@ -179,7 +171,7 @@ func NewClientWithUserAgent(ua string) Client {
PollingDuration: DefaultPollingDuration,
RetryAttempts: DefaultRetryAttempts,
RetryDuration: DefaultRetryDuration,
UserAgent: defaultUserAgent,
UserAgent: version.UserAgent(),
}
c.Sender = c.sender()
c.AddToUserAgent(ua)

View File

@ -223,6 +223,10 @@ func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) Se
return resp, err
}
resp, err = s.Do(rr.Request())
// if the error isn't temporary don't bother retrying
if err != nil && !IsTemporaryNetworkError(err) {
return nil, err
}
// we want to retry if err is not nil (e.g. transient network failure). note that for failed authentication
// resp and err will both have a value, so in this case we don't want to retry as it will never succeed.
if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) {

View File

@ -20,6 +20,7 @@ import (
"encoding/xml"
"fmt"
"io"
"net"
"net/http"
"net/url"
"reflect"
@ -216,3 +217,12 @@ func IsTokenRefreshError(err error) bool {
}
return false
}
// IsTemporaryNetworkError returns true if the specified error is a temporary network error or false
// if it's not. If the error doesn't implement the net.Error interface the return value is true.
func IsTemporaryNetworkError(err error) bool {
if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) {
return true
}
return false
}

View File

@ -136,29 +136,29 @@ func validatePtr(x reflect.Value, v Constraint) error {
func validateInt(x reflect.Value, v Constraint) error {
i := x.Int()
r, ok := v.Rule.(int)
r, ok := toInt64(v.Rule)
if !ok {
return createError(x, v, fmt.Sprintf("rule must be integer value for %v constraint; got: %v", v.Name, v.Rule))
}
switch v.Name {
case MultipleOf:
if i%int64(r) != 0 {
if i%r != 0 {
return createError(x, v, fmt.Sprintf("value must be a multiple of %v", r))
}
case ExclusiveMinimum:
if i <= int64(r) {
if i <= r {
return createError(x, v, fmt.Sprintf("value must be greater than %v", r))
}
case ExclusiveMaximum:
if i >= int64(r) {
if i >= r {
return createError(x, v, fmt.Sprintf("value must be less than %v", r))
}
case InclusiveMinimum:
if i < int64(r) {
if i < r {
return createError(x, v, fmt.Sprintf("value must be greater than or equal to %v", r))
}
case InclusiveMaximum:
if i > int64(r) {
if i > r {
return createError(x, v, fmt.Sprintf("value must be less than or equal to %v", r))
}
default:
@ -388,6 +388,17 @@ func createError(x reflect.Value, v Constraint, err string) error {
v.Target, v.Name, getInterfaceValue(x), err)
}
func toInt64(v interface{}) (int64, bool) {
if i64, ok := v.(int64); ok {
return i64, true
}
// older generators emit max constants as int, so if int64 fails fall back to int
if i32, ok := v.(int); ok {
return int64(i32), true
}
return 0, false
}
// NewErrorWithValidationError appends package type and method name in
// validation error.
//

View File

@ -1,5 +1,7 @@
package autorest
import "github.com/Azure/go-autorest/version"
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
@ -16,5 +18,5 @@ package autorest
// Version returns the semantic version (see http://semver.org).
func Version() string {
return "v10.5.0"
return version.Number
}

37
vendor/github.com/Azure/go-autorest/version/version.go generated vendored Normal file
View File

@ -0,0 +1,37 @@
package version
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"runtime"
)
// Number contains the semantic version of this SDK.
const Number = "v10.14.0"
var (
userAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s",
runtime.Version(),
runtime.GOARCH,
runtime.GOOS,
Number,
)
)
// UserAgent returns a string containing the Go version, system archityecture and OS, and the go-autorest version.
func UserAgent() string {
return userAgent
}

View File

@ -81,10 +81,12 @@ func (adapter *Decoder) More() bool {
if iter.Error != nil {
return false
}
if iter.head != iter.tail {
return true
c := iter.nextToken()
if c == 0 {
return false
}
return iter.loadMore()
iter.unreadByte()
return c != ']' && c != '}'
}
// Buffered remaining buffer
@ -98,7 +100,7 @@ func (adapter *Decoder) Buffered() io.Reader {
func (adapter *Decoder) UseNumber() {
cfg := adapter.iter.cfg.configBeforeFrozen
cfg.UseNumber = true
adapter.iter.cfg = cfg.frozeWithCacheReuse()
adapter.iter.cfg = cfg.frozeWithCacheReuse(adapter.iter.cfg.extraExtensions)
}
// DisallowUnknownFields causes the Decoder to return an error when the destination
@ -107,7 +109,7 @@ func (adapter *Decoder) UseNumber() {
func (adapter *Decoder) DisallowUnknownFields() {
cfg := adapter.iter.cfg.configBeforeFrozen
cfg.DisallowUnknownFields = true
adapter.iter.cfg = cfg.frozeWithCacheReuse()
adapter.iter.cfg = cfg.frozeWithCacheReuse(adapter.iter.cfg.extraExtensions)
}
// NewEncoder same as json.NewEncoder
@ -132,14 +134,14 @@ func (adapter *Encoder) Encode(val interface{}) error {
func (adapter *Encoder) SetIndent(prefix, indent string) {
config := adapter.stream.cfg.configBeforeFrozen
config.IndentionStep = len(indent)
adapter.stream.cfg = config.frozeWithCacheReuse()
adapter.stream.cfg = config.frozeWithCacheReuse(adapter.stream.cfg.extraExtensions)
}
// SetEscapeHTML escape html by default, set to false to disable
func (adapter *Encoder) SetEscapeHTML(escapeHTML bool) {
config := adapter.stream.cfg.configBeforeFrozen
config.EscapeHTML = escapeHTML
adapter.stream.cfg = config.frozeWithCacheReuse()
adapter.stream.cfg = config.frozeWithCacheReuse(adapter.stream.cfg.extraExtensions)
}
// Valid reports whether data is a valid JSON encoding.

View File

@ -74,7 +74,9 @@ type frozenConfig struct {
disallowUnknownFields bool
decoderCache *concurrent.Map
encoderCache *concurrent.Map
extensions []Extension
encoderExtension Extension
decoderExtension Extension
extraExtensions []Extension
streamPool *sync.Pool
iteratorPool *sync.Pool
caseSensitive bool
@ -158,22 +160,21 @@ func (cfg Config) Froze() API {
if cfg.ValidateJsonRawMessage {
api.validateJsonRawMessage(encoderExtension)
}
if len(encoderExtension) > 0 {
api.extensions = append(api.extensions, encoderExtension)
}
if len(decoderExtension) > 0 {
api.extensions = append(api.extensions, decoderExtension)
}
api.encoderExtension = encoderExtension
api.decoderExtension = decoderExtension
api.configBeforeFrozen = cfg
return api
}
func (cfg Config) frozeWithCacheReuse() *frozenConfig {
func (cfg Config) frozeWithCacheReuse(extraExtensions []Extension) *frozenConfig {
api := getFrozenConfigFromCache(cfg)
if api != nil {
return api
}
api = cfg.Froze().(*frozenConfig)
for _, extension := range extraExtensions {
api.RegisterExtension(extension)
}
addFrozenConfigToCache(cfg, api)
return api
}
@ -190,7 +191,7 @@ func (cfg *frozenConfig) validateJsonRawMessage(extension EncoderExtension) {
stream.WriteRaw(string(rawMessage))
}
}, func(ptr unsafe.Pointer) bool {
return false
return len(*((*json.RawMessage)(ptr))) == 0
}}
extension[reflect2.TypeOfPtr((*json.RawMessage)(nil)).Elem()] = encoder
extension[reflect2.TypeOfPtr((*RawMessage)(nil)).Elem()] = encoder
@ -219,7 +220,9 @@ func (cfg *frozenConfig) getTagKey() string {
}
func (cfg *frozenConfig) RegisterExtension(extension Extension) {
cfg.extensions = append(cfg.extensions, extension)
cfg.extraExtensions = append(cfg.extraExtensions, extension)
copied := cfg.configBeforeFrozen
cfg.configBeforeFrozen = copied
}
type lossyFloat32Encoder struct {
@ -314,7 +317,7 @@ func (cfg *frozenConfig) MarshalIndent(v interface{}, prefix, indent string) ([]
}
newCfg := cfg.configBeforeFrozen
newCfg.IndentionStep = len(indent)
return newCfg.frozeWithCacheReuse().Marshal(v)
return newCfg.frozeWithCacheReuse(cfg.extraExtensions).Marshal(v)
}
func (cfg *frozenConfig) UnmarshalFromString(str string, v interface{}) error {

View File

@ -2,7 +2,7 @@ package jsoniter
import (
"fmt"
"unicode"
"strings"
)
// ReadObject read one field from object.
@ -96,13 +96,12 @@ func (iter *Iterator) readFieldHash() int64 {
}
func calcHash(str string, caseSensitive bool) int64 {
if !caseSensitive {
str = strings.ToLower(str)
}
hash := int64(0x811c9dc5)
for _, b := range str {
if caseSensitive {
hash ^= int64(b)
} else {
hash ^= int64(unicode.ToLower(b))
}
for _, b := range []byte(str) {
hash ^= int64(b)
hash *= 0x1000193
}
return int64(hash)

View File

@ -120,7 +120,8 @@ func decoderOfType(ctx *ctx, typ reflect2.Type) ValDecoder {
for _, extension := range extensions {
decoder = extension.DecorateDecoder(typ, decoder)
}
for _, extension := range ctx.extensions {
decoder = ctx.decoderExtension.DecorateDecoder(typ, decoder)
for _, extension := range ctx.extraExtensions {
decoder = extension.DecorateDecoder(typ, decoder)
}
return decoder
@ -222,7 +223,8 @@ func encoderOfType(ctx *ctx, typ reflect2.Type) ValEncoder {
for _, extension := range extensions {
encoder = extension.DecorateEncoder(typ, encoder)
}
for _, extension := range ctx.extensions {
encoder = ctx.encoderExtension.DecorateEncoder(typ, encoder)
for _, extension := range ctx.extraExtensions {
encoder = extension.DecorateEncoder(typ, encoder)
}
return encoder

View File

@ -246,7 +246,8 @@ func getTypeDecoderFromExtension(ctx *ctx, typ reflect2.Type) ValDecoder {
for _, extension := range extensions {
decoder = extension.DecorateDecoder(typ, decoder)
}
for _, extension := range ctx.extensions {
decoder = ctx.decoderExtension.DecorateDecoder(typ, decoder)
for _, extension := range ctx.extraExtensions {
decoder = extension.DecorateDecoder(typ, decoder)
}
}
@ -259,14 +260,18 @@ func _getTypeDecoderFromExtension(ctx *ctx, typ reflect2.Type) ValDecoder {
return decoder
}
}
for _, extension := range ctx.extensions {
decoder := ctx.decoderExtension.CreateDecoder(typ)
if decoder != nil {
return decoder
}
for _, extension := range ctx.extraExtensions {
decoder := extension.CreateDecoder(typ)
if decoder != nil {
return decoder
}
}
typeName := typ.String()
decoder := typeDecoders[typeName]
decoder = typeDecoders[typeName]
if decoder != nil {
return decoder
}
@ -286,7 +291,8 @@ func getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder {
for _, extension := range extensions {
encoder = extension.DecorateEncoder(typ, encoder)
}
for _, extension := range ctx.extensions {
encoder = ctx.encoderExtension.DecorateEncoder(typ, encoder)
for _, extension := range ctx.extraExtensions {
encoder = extension.DecorateEncoder(typ, encoder)
}
}
@ -300,14 +306,18 @@ func _getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder {
return encoder
}
}
for _, extension := range ctx.extensions {
encoder := ctx.encoderExtension.CreateEncoder(typ)
if encoder != nil {
return encoder
}
for _, extension := range ctx.extraExtensions {
encoder := extension.CreateEncoder(typ)
if encoder != nil {
return encoder
}
}
typeName := typ.String()
encoder := typeEncoders[typeName]
encoder = typeEncoders[typeName]
if encoder != nil {
return encoder
}
@ -393,7 +403,9 @@ func createStructDescriptor(ctx *ctx, typ reflect2.Type, bindings []*Binding, em
for _, extension := range extensions {
extension.UpdateStructDescriptor(structDescriptor)
}
for _, extension := range ctx.extensions {
ctx.encoderExtension.UpdateStructDescriptor(structDescriptor)
ctx.decoderExtension.UpdateStructDescriptor(structDescriptor)
for _, extension := range ctx.extraExtensions {
extension.UpdateStructDescriptor(structDescriptor)
}
processTags(structDescriptor, ctx.frozenConfig)

View File

@ -39,7 +39,11 @@ func encoderOfMap(ctx *ctx, typ reflect2.Type) ValEncoder {
}
func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder {
for _, extension := range ctx.extensions {
decoder := ctx.decoderExtension.CreateMapKeyDecoder(typ)
if decoder != nil {
return decoder
}
for _, extension := range ctx.extraExtensions {
decoder := extension.CreateMapKeyDecoder(typ)
if decoder != nil {
return decoder
@ -77,7 +81,11 @@ func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder {
}
func encoderOfMapKey(ctx *ctx, typ reflect2.Type) ValEncoder {
for _, extension := range ctx.extensions {
encoder := ctx.encoderExtension.CreateMapKeyEncoder(typ)
if encoder != nil {
return encoder
}
for _, extension := range ctx.extraExtensions {
encoder := extension.CreateMapKeyEncoder(typ)
if encoder != nil {
return encoder

View File

@ -6,4 +6,4 @@ import "unsafe"
func resolveTypeOff(rtype unsafe.Pointer, off int32) unsafe.Pointer {
return nil
}
}

View File

@ -1,9 +1,9 @@
package reflect2
import (
"github.com/modern-go/concurrent"
"reflect"
"unsafe"
"github.com/modern-go/concurrent"
)
type Type interface {
@ -136,7 +136,7 @@ type frozenConfig struct {
func (cfg Config) Froze() *frozenConfig {
return &frozenConfig{
useSafeImplementation: cfg.UseSafeImplementation,
cache: concurrent.NewMap(),
cache: concurrent.NewMap(),
}
}
@ -150,6 +150,9 @@ func (cfg *frozenConfig) TypeOf(obj interface{}) Type {
}
func (cfg *frozenConfig) Type2(type1 reflect.Type) Type {
if type1 == nil {
return nil
}
cacheKey := uintptr(unpackEFace(type1).data)
typeObj, found := cfg.cache.Load(cacheKey)
if found {
@ -213,6 +216,9 @@ func TypeOfPtr(obj interface{}) PtrType {
}
func Type2(type1 reflect.Type) Type {
if type1 == nil {
return nil
}
return ConfigUnsafe.Type2(type1)
}
@ -279,4 +285,14 @@ func likePtrType(typ reflect.Type) bool {
func NoEscape(p unsafe.Pointer) unsafe.Pointer {
x := uintptr(p)
return unsafe.Pointer(x ^ 0)
}
}
func UnsafeCastString(str string) []byte {
stringHeader := (*reflect.StringHeader)(unsafe.Pointer(&str))
sliceHeader := &reflect.SliceHeader{
Data: stringHeader.Data,
Cap: stringHeader.Len,
Len: stringHeader.Len,
}
return *(*[]byte)(unsafe.Pointer(sliceHeader))
}

View File

@ -11,20 +11,20 @@ func DefaultTypeOfKind(kind reflect.Kind) Type {
}
var kindTypes = map[reflect.Kind]Type{
reflect.Bool: TypeOf(true),
reflect.Uint8: TypeOf(uint8(0)),
reflect.Int8: TypeOf(int8(0)),
reflect.Uint16: TypeOf(uint16(0)),
reflect.Int16: TypeOf(int16(0)),
reflect.Uint32: TypeOf(uint32(0)),
reflect.Int32: TypeOf(int32(0)),
reflect.Uint64: TypeOf(uint64(0)),
reflect.Int64: TypeOf(int64(0)),
reflect.Uint: TypeOf(uint(0)),
reflect.Int: TypeOf(int(0)),
reflect.Float32: TypeOf(float32(0)),
reflect.Float64: TypeOf(float64(0)),
reflect.Uintptr: TypeOf(uintptr(0)),
reflect.String: TypeOf(""),
reflect.Bool: TypeOf(true),
reflect.Uint8: TypeOf(uint8(0)),
reflect.Int8: TypeOf(int8(0)),
reflect.Uint16: TypeOf(uint16(0)),
reflect.Int16: TypeOf(int16(0)),
reflect.Uint32: TypeOf(uint32(0)),
reflect.Int32: TypeOf(int32(0)),
reflect.Uint64: TypeOf(uint64(0)),
reflect.Int64: TypeOf(int64(0)),
reflect.Uint: TypeOf(uint(0)),
reflect.Int: TypeOf(int(0)),
reflect.Float32: TypeOf(float32(0)),
reflect.Float64: TypeOf(float64(0)),
reflect.Uintptr: TypeOf(uintptr(0)),
reflect.String: TypeOf(""),
reflect.UnsafePointer: TypeOf(unsafe.Pointer(nil)),
}
}

View File

@ -55,4 +55,4 @@ func (field *safeField) Get(obj interface{}) interface{} {
func (field *safeField) UnsafeGet(obj unsafe.Pointer) unsafe.Pointer {
panic("does not support unsafe operation")
}
}

View File

@ -76,8 +76,8 @@ func (type2 *safeMapType) UnsafeIterate(obj unsafe.Pointer) MapIterator {
}
type safeMapIterator struct {
i int
m reflect.Value
i int
m reflect.Value
keys []reflect.Value
}
@ -98,4 +98,4 @@ func (iter *safeMapIterator) Next() (interface{}, interface{}) {
func (iter *safeMapIterator) UnsafeNext() (unsafe.Pointer, unsafe.Pointer) {
panic("does not support unsafe operation")
}
}

View File

@ -89,4 +89,4 @@ func (type2 *safeSliceType) Cap(obj interface{}) int {
func (type2 *safeSliceType) UnsafeCap(ptr unsafe.Pointer) int {
panic("does not support unsafe operation")
}
}

View File

@ -10,4 +10,20 @@ func (type2 *safeStructType) FieldByName(name string) StructField {
panic("field " + name + " not found")
}
return &safeField{StructField: field}
}
}
func (type2 *safeStructType) Field(i int) StructField {
return &safeField{StructField: type2.Type.Field(i)}
}
func (type2 *safeStructType) FieldByIndex(index []int) StructField {
return &safeField{StructField: type2.Type.FieldByIndex(index)}
}
func (type2 *safeStructType) FieldByNameFunc(match func(string) bool) StructField {
field, found := type2.Type.FieldByNameFunc(match)
if !found {
panic("field match condition not found in " + type2.Type.String())
}
return &safeField{StructField: field}
}

View File

@ -75,4 +75,4 @@ func (type2 *safeType) UnsafeSet(ptr unsafe.Pointer, val unsafe.Pointer) {
func (type2 *safeType) AssignableTo(anotherType Type) bool {
return type2.Type1().AssignableTo(anotherType.Type1())
}
}

View File

@ -1,10 +1,10 @@
package reflect2
import (
"unsafe"
"reflect"
"runtime"
"strings"
"unsafe"
)
// typelinks1 for 1.5 ~ 1.6
@ -16,6 +16,7 @@ func typelinks1() [][]unsafe.Pointer
func typelinks2() (sections []unsafe.Pointer, offset [][]int32)
var types = map[string]reflect.Type{}
var packages = map[string]map[string]reflect.Type{}
func init() {
ver := runtime.Version()
@ -36,11 +37,25 @@ func loadGo15Types() {
(*emptyInterface)(unsafe.Pointer(&obj)).word = typePtr
typ := obj.(reflect.Type)
if typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct {
types[typ.Elem().String()] = typ.Elem()
loadedType := typ.Elem()
pkgTypes := packages[loadedType.PkgPath()]
if pkgTypes == nil {
pkgTypes = map[string]reflect.Type{}
packages[loadedType.PkgPath()] = pkgTypes
}
types[loadedType.String()] = loadedType
pkgTypes[loadedType.Name()] = loadedType
}
if typ.Kind() == reflect.Slice && typ.Elem().Kind() == reflect.Ptr &&
typ.Elem().Elem().Kind() == reflect.Struct {
types[typ.Elem().Elem().String()] = typ.Elem().Elem()
loadedType := typ.Elem().Elem()
pkgTypes := packages[loadedType.PkgPath()]
if pkgTypes == nil {
pkgTypes = map[string]reflect.Type{}
packages[loadedType.PkgPath()] = pkgTypes
}
types[loadedType.String()] = loadedType
pkgTypes[loadedType.Name()] = loadedType
}
}
}
@ -55,7 +70,14 @@ func loadGo17Types() {
(*emptyInterface)(unsafe.Pointer(&obj)).word = resolveTypeOff(unsafe.Pointer(rodata), off)
typ := obj.(reflect.Type)
if typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct {
types[typ.Elem().String()] = typ.Elem()
loadedType := typ.Elem()
pkgTypes := packages[loadedType.PkgPath()]
if pkgTypes == nil {
pkgTypes = map[string]reflect.Type{}
packages[loadedType.PkgPath()] = pkgTypes
}
types[loadedType.String()] = loadedType
pkgTypes[loadedType.Name()] = loadedType
}
}
}
@ -70,3 +92,12 @@ type emptyInterface struct {
func TypeByName(typeName string) Type {
return Type2(types[typeName])
}
// TypeByPackageName return the type by its package and name
func TypeByPackageName(pkgPath string, name string) Type {
pkgTypes := packages[pkgPath]
if pkgTypes == nil {
return nil
}
return Type2(pkgTypes[name])
}

View File

@ -1,8 +1,8 @@
package reflect2
import (
"unsafe"
"reflect"
"unsafe"
)
type UnsafeArrayType struct {

View File

@ -1,8 +1,8 @@
package reflect2
import (
"unsafe"
"reflect"
"unsafe"
)
type eface struct {
@ -56,4 +56,4 @@ func (type2 *UnsafeEFaceType) Indirect(obj interface{}) interface{} {
func (type2 *UnsafeEFaceType) UnsafeIndirect(ptr unsafe.Pointer) interface{} {
return *(*interface{})(ptr)
}
}

View File

@ -1,8 +1,8 @@
package reflect2
import (
"unsafe"
"reflect"
"unsafe"
)
type iface struct {
@ -61,4 +61,4 @@ func (type2 *UnsafeIFaceType) UnsafeIsNil(ptr unsafe.Pointer) bool {
return true
}
return false
}
}

View File

@ -67,4 +67,4 @@ func add(p unsafe.Pointer, x uintptr, whySafe string) unsafe.Pointer {
// the benefit is to surface this assumption at the call site.)
func arrayAt(p unsafe.Pointer, i int, eltSize uintptr, whySafe string) unsafe.Pointer {
return add(p, uintptr(i)*eltSize, "i < len")
}
}

View File

@ -1,8 +1,8 @@
package reflect2
import (
"unsafe"
"reflect"
"unsafe"
)
type UnsafePtrType struct {

View File

@ -1,8 +1,8 @@
package reflect2
import (
"unsafe"
"reflect"
"unsafe"
)
// sliceHeader is a safe version of SliceHeader used within this package.

View File

@ -56,4 +56,4 @@ func (type2 *UnsafeStructType) FieldByNameFunc(match func(string) bool) StructFi
panic("field match condition not found in " + type2.Type.String())
}
return newUnsafeStructField(type2, structField)
}
}

View File

@ -1,8 +1,8 @@
package reflect2
import (
"unsafe"
"reflect"
"unsafe"
)
type unsafeType struct {
@ -15,7 +15,7 @@ func newUnsafeType(cfg *frozenConfig, type1 reflect.Type) *unsafeType {
return &unsafeType{
safeType: safeType{
Type: type1,
cfg: cfg,
cfg: cfg,
},
rtype: unpackEFace(type1).data,
ptrRType: unpackEFace(reflect.PtrTo(type1)).data,