@@ -7,9 +7,11 @@ import (
77 "context"
88 "encoding/base64"
99 "fmt"
10+ "net/http"
1011 "strings"
1112
1213 "github.com/Azure/aks-secure-tls-bootstrap/client/internal/cloud"
14+ "github.com/Azure/aks-secure-tls-bootstrap/client/internal/imds"
1315 "github.com/Azure/aks-secure-tls-bootstrap/client/internal/log"
1416 "github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry"
1517 "github.com/Azure/go-autorest/autorest/adal"
@@ -30,15 +32,11 @@ const (
3032)
3133
3234// extractAccessTokenFunc extracts an oauth access token from the specified service principal token after a refresh, fake implementations given in unit tests.
33- type extractAccessTokenFunc func (token * adal.ServicePrincipalToken ) (string , error )
35+ type extractAccessTokenFunc func (token * adal.ServicePrincipalToken , isMSI bool ) (string , error )
3436
35- func extractAccessToken (token * adal.ServicePrincipalToken ) (string , error ) {
37+ func extractAccessToken (token * adal.ServicePrincipalToken , isMSI bool ) (string , error ) {
3638 if err := token .Refresh (); err != nil {
37- return "" , & bootstrapError {
38- errorType : ErrorTypeGetAccessTokenFailure ,
39- retryable : true ,
40- inner : fmt .Errorf ("obtaining fresh access token: %w" , err ),
41- }
39+ return "" , tokenRefreshErrorToGetAccessTokenFailure (err , isMSI )
4240 }
4341 return token .OAuthToken (), nil
4442}
@@ -56,56 +54,117 @@ func (c *client) getAccessToken(ctx context.Context, userAssignedIdentityID, res
5654
5755 if userAssignedID != "" {
5856 logger .Info ("generating MSI access token" , zap .String ("clientId" , userAssignedID ))
59- token , err := adal .NewServicePrincipalTokenFromManagedIdentity (resource , & adal.ManagedIdentityOptions {
57+ msiToken , err := adal .NewServicePrincipalTokenFromManagedIdentity (resource , & adal.ManagedIdentityOptions {
6058 ClientID : userAssignedID ,
6159 })
6260 if err != nil {
6361 return "" , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("generating MSI access token: %w" , err ))
6462 }
6563 // to avoid falling too deep into exponential backoff implemented by adal, which follows the public retry guidance:
6664 // https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#retry-guidance
67- token .MaxMSIRefreshAttempts = maxMSIRefreshAttempts
68- return c .extractAccessTokenFunc (token )
65+ msiToken .MaxMSIRefreshAttempts = maxMSIRefreshAttempts
66+ return c .extractAccessTokenFunc (msiToken , true )
6967 }
7068
7169 if cloudProviderConfig .ClientID == clientIDForMSI {
7270 return "" , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("client ID within cloud provider config indicates usage of a managed identity, though no user-assigned identity ID was provided" ))
7371 }
7472
73+ servicePrincipalToken , err := getServicePrincipalToken (ctx , resource , cloudProviderConfig )
74+ if err != nil {
75+ return "" , err
76+ }
77+
78+ return c .extractAccessTokenFunc (servicePrincipalToken , false )
79+ }
80+
81+ func getServicePrincipalToken (ctx context.Context , resource string , cloudProviderConfig * cloud.ProviderConfig ) (* adal.ServicePrincipalToken , error ) {
82+ logger := log .MustGetLogger (ctx )
83+
84+ secret := maybeB64Decode (cloudProviderConfig .ClientSecret )
85+
7586 env , err := azure .EnvironmentFromName (cloudProviderConfig .CloudName )
7687 if err != nil {
77- return "" , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("getting azure environment config for cloud %q: %w" , cloudProviderConfig .CloudName , err ))
88+ return nil , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("getting azure environment config for cloud %q: %w" , cloudProviderConfig .CloudName , err ))
7889 }
7990 oauthConfig , err := adal .NewOAuthConfig (env .ActiveDirectoryEndpoint , cloudProviderConfig .TenantID )
8091 if err != nil {
81- return "" , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("creating oauth config with azure environment: %w" , err ))
92+ return nil , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("creating oauth config with azure environment: %w" , err ))
8293 }
8394
84- if ! strings .HasPrefix (cloudProviderConfig . ClientSecret , certificateSecretPrefix ) {
85- logger .Info ("generating SPN access token with username and password " , zap .String ("clientId" , cloudProviderConfig .ClientID ))
86- token , err := adal .NewServicePrincipalToken (* oauthConfig , cloudProviderConfig .ClientID , cloudProviderConfig . ClientSecret , resource )
95+ if ! strings .HasPrefix (secret , certificateSecretPrefix ) {
96+ logger .Info ("generating service principal access token with client secret " , zap .String ("clientId" , cloudProviderConfig .ClientID ))
97+ token , err := adal .NewServicePrincipalToken (* oauthConfig , cloudProviderConfig .ClientID , secret , resource )
8798 if err != nil {
88- return "" , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("generating SPN access token with username and password : %w" , err ))
99+ return nil , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("generating service principal access token with client secret : %w" , err ))
89100 }
90- return c . extractAccessTokenFunc ( token )
101+ return token , nil
91102 }
92103
93- logger .Info ("client secret contains certificate data, using certificate to generate SPN access token" , zap .String ("clientId" , cloudProviderConfig .ClientID ))
104+ logger .Info ("client secret contains certificate data, using certificate to generate service principal access token" , zap .String ("clientId" , cloudProviderConfig .ClientID ))
94105
95- certData , err := base64 .StdEncoding .DecodeString (strings .TrimPrefix (cloudProviderConfig . ClientSecret , certificateSecretPrefix ))
106+ certData , err := base64 .StdEncoding .DecodeString (strings .TrimPrefix (secret , certificateSecretPrefix ))
96107 if err != nil {
97- return "" , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("b64-decoding certificate data in client secret: %w" , err ))
108+ return nil , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("b64-decoding certificate data in client secret: %w" , err ))
98109 }
99110 certificate , privateKey , err := adal .DecodePfxCertificateData (certData , "" )
100111 if err != nil {
101- return "" , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("decoding pfx certificate data in client secret: %w" , err ))
112+ return nil , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("decoding pfx certificate data in client secret: %w" , err ))
102113 }
103114
104- logger .Info ("generating SPN access token with certificate" , zap .String ("clientId" , cloudProviderConfig .ClientID ))
115+ logger .Info ("generating service principal access token with certificate" , zap .String ("clientId" , cloudProviderConfig .ClientID ))
105116 token , err := adal .NewServicePrincipalTokenFromCertificate (* oauthConfig , cloudProviderConfig .ClientID , certificate , privateKey , resource )
106117 if err != nil {
107- return "" , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("generating SPN access token with certificate: %w" , err ))
118+ return nil , makeNonRetryableGetAccessTokenFailure (fmt .Errorf ("generating service principal access token with certificate: %w" , err ))
119+ }
120+
121+ return token , nil
122+ }
123+
// maybeB64Decode returns the base64-decoded form of str when str is a valid standard-base64
// encoding of UTF-8 text, and returns str unchanged otherwise.
//
// A plain-text value can coincidentally be valid base64 (e.g. "abcd", or any secret drawn from
// the base64 alphabet whose length is a multiple of four). Decoding such a value yields arbitrary
// bytes rather than the intended secret, so a successful decode is only trusted when its result
// is valid UTF-8.
func maybeB64Decode(str string) string {
	decoded, err := base64.StdEncoding.DecodeString(str)
	if err != nil {
		return str
	}

	text := string(decoded)
	// strings.ToValidUTF8 with an empty replacement drops invalid byte sequences, so the result
	// equals its input only when the input was already valid UTF-8.
	if strings.ToValidUTF8(text, "") != text {
		return str
	}
	return text
}
130+
131+ func makeNonRetryableGetAccessTokenFailure (err error ) error {
132+ return & bootstrapError {
133+ errorType : ErrorTypeGetAccessTokenFailure ,
134+ retryable : false ,
135+ inner : err ,
136+ }
137+ }
138+
139+ func tokenRefreshErrorToGetAccessTokenFailure (err error , isMSI bool ) error {
140+ bootstrapErr := & bootstrapError {
141+ errorType : ErrorTypeGetAccessTokenFailure ,
142+ retryable : true , // optimistically consider the error retryable from the start
143+ inner : fmt .Errorf ("obtaining fresh access token: %w" , err ),
144+ }
145+
146+ rerr , ok := err .(adal.TokenRefreshError )
147+ if ! ok {
148+ return bootstrapErr
149+ }
150+
151+ resp := rerr .Response ()
152+ if resp == nil {
153+ return bootstrapErr
154+ }
155+
156+ if ! isMSI {
157+ bootstrapErr .retryable = resp .StatusCode >= http .StatusInternalServerError
158+ return bootstrapErr
159+ }
160+
161+ if resp .StatusCode != http .StatusBadRequest {
162+ bootstrapErr .retryable = imds .IsRetryableHTTPStatusCode (resp .StatusCode )
163+ return bootstrapErr
108164 }
109165
110- return c .extractAccessTokenFunc (token )
166+ // 400s aren't normally retryable, though identity assignment can sometimes take a bit of time to propagate to IMDS,
167+ // so we treat "Identity not found" errors as retryable
168+ bootstrapErr .retryable = strings .Contains (strings .ToLower (err .Error ()), "identity not found" )
169+ return bootstrapErr
111170}
0 commit comments