Skip to content

Commit f18eb62

Browse files
cameronmeissner and Cameron Meissner authored
fix(client): correctly handle b64-encoded service principal certificate credential, improve retry logic (#154)
Co-authored-by: Cameron Meissner <cameissner@microsoft.com>
1 parent 770c58d commit f18eb62

16 files changed

Lines changed: 595 additions & 145 deletions

File tree

client/cmd/client/main.go

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -116,16 +116,17 @@ func run(ctx context.Context) int {
116116
case errors.Is(err, context.Canceled):
117117
logger.Error("context was cancelled before bootstrapping could complete")
118118
case errors.Is(err, context.DeadlineExceeded):
119+
err = errors.Unwrap(err)
119120
logger.Error(
120121
"failed to successfully bootstrap before the specified deadline",
121-
zap.Error(errors.Unwrap(err)),
122+
zap.Error(err),
122123
zap.Time("deadline", deadline),
123124
zap.Duration("deadlineDuration", config.Deadline),
124125
)
125126
default:
126-
logger.Error("failed to bootstrap", zap.Error(errors.Unwrap(err)))
127+
logger.Error("failed to bootstrap", zap.Error(err))
127128
}
128-
result.FinalError = errors.Unwrap(err).Error()
129+
result.FinalError = err.Error()
129130
exitCode = 1
130131
}
131132

client/internal/bootstrap/auth.go

Lines changed: 83 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,11 @@ import (
77
"context"
88
"encoding/base64"
99
"fmt"
10+
"net/http"
1011
"strings"
1112

1213
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/cloud"
14+
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/imds"
1315
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/log"
1416
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry"
1517
"github.com/Azure/go-autorest/autorest/adal"
@@ -30,15 +32,11 @@ const (
3032
)
3133

3234
// extractAccessTokenFunc extracts an oauth access token from the specified service principal token after a refresh, fake implementations given in unit tests.
33-
type extractAccessTokenFunc func(token *adal.ServicePrincipalToken) (string, error)
35+
type extractAccessTokenFunc func(token *adal.ServicePrincipalToken, isMSI bool) (string, error)
3436

35-
func extractAccessToken(token *adal.ServicePrincipalToken) (string, error) {
37+
// extractAccessToken refreshes the given service principal token and returns the resulting OAuth access token.
// When the refresh fails, the error is handed to tokenRefreshErrorToGetAccessTokenFailure together with isMSI,
// and the error it produces is returned alongside an empty token string.
func extractAccessToken(token *adal.ServicePrincipalToken, isMSI bool) (string, error) {
3638
if err := token.Refresh(); err != nil {
37-
return "", &bootstrapError{
38-
errorType: ErrorTypeGetAccessTokenFailure,
39-
retryable: true,
40-
inner: fmt.Errorf("obtaining fresh access token: %w", err),
41-
}
// the removed lines above always marked a refresh failure as retryable; the added line below
// delegates that decision to tokenRefreshErrorToGetAccessTokenFailure instead
39+
return "", tokenRefreshErrorToGetAccessTokenFailure(err, isMSI)
4240
}
4341
return token.OAuthToken(), nil
4442
}
@@ -56,56 +54,117 @@ func (c *client) getAccessToken(ctx context.Context, userAssignedIdentityID, res
5654

5755
if userAssignedID != "" {
5856
logger.Info("generating MSI access token", zap.String("clientId", userAssignedID))
59-
token, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, &adal.ManagedIdentityOptions{
57+
msiToken, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, &adal.ManagedIdentityOptions{
6058
ClientID: userAssignedID,
6159
})
6260
if err != nil {
6361
return "", makeNonRetryableGetAccessTokenFailure(fmt.Errorf("generating MSI access token: %w", err))
6462
}
6563
// to avoid falling too deep into exponential backoff implemented by adal, which follows the public retry guidance:
6664
// https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#retry-guidance
67-
token.MaxMSIRefreshAttempts = maxMSIRefreshAttempts
68-
return c.extractAccessTokenFunc(token)
65+
msiToken.MaxMSIRefreshAttempts = maxMSIRefreshAttempts
66+
return c.extractAccessTokenFunc(msiToken, true)
6967
}
7068

7169
if cloudProviderConfig.ClientID == clientIDForMSI {
7270
return "", makeNonRetryableGetAccessTokenFailure(fmt.Errorf("client ID within cloud provider config indicates usage of a managed identity, though no user-assigned identity ID was provided"))
7371
}
7472

73+
servicePrincipalToken, err := getServicePrincipalToken(ctx, resource, cloudProviderConfig)
74+
if err != nil {
75+
return "", err
76+
}
77+
78+
return c.extractAccessTokenFunc(servicePrincipalToken, false)
79+
}
80+
81+
// getServicePrincipalToken builds an adal service principal token for the given resource from the
// client ID, client secret, tenant ID and cloud name carried by cloudProviderConfig.
//
// The client secret is first passed through maybeB64Decode. If the result does not start with
// certificateSecretPrefix it is used directly as a client secret. Otherwise the text after the prefix
// is base64-decoded, parsed as PFX data with an empty password, and the resulting certificate and
// private key are used to create the token.
//
// Every failure is wrapped with makeNonRetryableGetAccessTokenFailure and returned with a nil token.
// The token is not refreshed here.
func getServicePrincipalToken(ctx context.Context, resource string, cloudProviderConfig *cloud.ProviderConfig) (*adal.ServicePrincipalToken, error) {
82+
logger := log.MustGetLogger(ctx)
83+
// the configured secret may itself be base64-encoded, so unwrap it before inspecting its prefix
84+
secret := maybeB64Decode(cloudProviderConfig.ClientSecret)
85+
7586
env, err := azure.EnvironmentFromName(cloudProviderConfig.CloudName)
7687
if err != nil {
77-
return "", makeNonRetryableGetAccessTokenFailure(fmt.Errorf("getting azure environment config for cloud %q: %w", cloudProviderConfig.CloudName, err))
88+
return nil, makeNonRetryableGetAccessTokenFailure(fmt.Errorf("getting azure environment config for cloud %q: %w", cloudProviderConfig.CloudName, err))
7889
}
7990
oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, cloudProviderConfig.TenantID)
8091
if err != nil {
81-
return "", makeNonRetryableGetAccessTokenFailure(fmt.Errorf("creating oauth config with azure environment: %w", err))
92+
return nil, makeNonRetryableGetAccessTokenFailure(fmt.Errorf("creating oauth config with azure environment: %w", err))
8293
}
8394

// a secret without the certificate prefix is used as a plain client secret
84-
if !strings.HasPrefix(cloudProviderConfig.ClientSecret, certificateSecretPrefix) {
85-
logger.Info("generating SPN access token with username and password", zap.String("clientId", cloudProviderConfig.ClientID))
86-
token, err := adal.NewServicePrincipalToken(*oauthConfig, cloudProviderConfig.ClientID, cloudProviderConfig.ClientSecret, resource)
95+
if !strings.HasPrefix(secret, certificateSecretPrefix) {
96+
logger.Info("generating service principal access token with client secret", zap.String("clientId", cloudProviderConfig.ClientID))
97+
token, err := adal.NewServicePrincipalToken(*oauthConfig, cloudProviderConfig.ClientID, secret, resource)
8798
if err != nil {
88-
return "", makeNonRetryableGetAccessTokenFailure(fmt.Errorf("generating SPN access token with username and password: %w", err))
99+
return nil, makeNonRetryableGetAccessTokenFailure(fmt.Errorf("generating service principal access token with client secret: %w", err))
89100
}
90-
return c.extractAccessTokenFunc(token)
101+
return token, nil
91102
}
92103

93-
logger.Info("client secret contains certificate data, using certificate to generate SPN access token", zap.String("clientId", cloudProviderConfig.ClientID))
104+
logger.Info("client secret contains certificate data, using certificate to generate service principal access token", zap.String("clientId", cloudProviderConfig.ClientID))
94105

// everything after the certificate prefix is base64-encoded PFX data, decoded with an empty password
95-
certData, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(cloudProviderConfig.ClientSecret, certificateSecretPrefix))
106+
certData, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(secret, certificateSecretPrefix))
96107
if err != nil {
97-
return "", makeNonRetryableGetAccessTokenFailure(fmt.Errorf("b64-decoding certificate data in client secret: %w", err))
108+
return nil, makeNonRetryableGetAccessTokenFailure(fmt.Errorf("b64-decoding certificate data in client secret: %w", err))
98109
}
99110
certificate, privateKey, err := adal.DecodePfxCertificateData(certData, "")
100111
if err != nil {
101-
return "", makeNonRetryableGetAccessTokenFailure(fmt.Errorf("decoding pfx certificate data in client secret: %w", err))
112+
return nil, makeNonRetryableGetAccessTokenFailure(fmt.Errorf("decoding pfx certificate data in client secret: %w", err))
102113
}
103114

104-
logger.Info("generating SPN access token with certificate", zap.String("clientId", cloudProviderConfig.ClientID))
115+
logger.Info("generating service principal access token with certificate", zap.String("clientId", cloudProviderConfig.ClientID))
105116
token, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, cloudProviderConfig.ClientID, certificate, privateKey, resource)
106117
if err != nil {
107-
return "", makeNonRetryableGetAccessTokenFailure(fmt.Errorf("generating SPN access token with certificate: %w", err))
118+
return nil, makeNonRetryableGetAccessTokenFailure(fmt.Errorf("generating service principal access token with certificate: %w", err))
119+
}
120+
121+
return token, nil
122+
}
123+
124+
// maybeB64Decode returns the base64-decoded form of str when str is a standard
// base64 encoding of valid UTF-8 text, and returns str unchanged otherwise.
//
// A successful decode alone does not prove that str was encoded on purpose: a
// plain secret made up only of base64-alphabet characters (for example "abcd")
// also decodes without error, but to arbitrary bytes. Requiring the decoded
// bytes to be valid UTF-8 leaves such secrets intact while still unwrapping
// deliberately encoded text.
func maybeB64Decode(str string) string {
	raw, err := base64.StdEncoding.DecodeString(str)
	if err != nil {
		return str
	}

	decoded := string(raw)
	// ToValidUTF8 with an empty replacement drops every invalid byte sequence,
	// so the result equals its input only when the input is already valid UTF-8
	if strings.ToValidUTF8(decoded, "") != decoded {
		return str
	}
	return decoded
}
130+
131+
func makeNonRetryableGetAccessTokenFailure(err error) error {
132+
return &bootstrapError{
133+
errorType: ErrorTypeGetAccessTokenFailure,
134+
retryable: false,
135+
inner: err,
136+
}
137+
}
138+
139+
// tokenRefreshErrorToGetAccessTokenFailure wraps a token refresh error in a bootstrapError of type
// ErrorTypeGetAccessTokenFailure and decides whether that failure should be retried.
//
// The error starts out retryable and stays so when err is not an adal.TokenRefreshError or carries
// no HTTP response. Otherwise the response status code decides:
//   - non-MSI tokens are retryable only for status codes >= 500
//   - MSI tokens with any status other than 400 defer to imds.IsRetryableHTTPStatusCode
//   - MSI tokens with status 400 are retryable only when the error text contains "identity not found"
func tokenRefreshErrorToGetAccessTokenFailure(err error, isMSI bool) error {
140+
bootstrapErr := &bootstrapError{
141+
errorType: ErrorTypeGetAccessTokenFailure,
142+
retryable: true, // optimistically consider the error retryable from the start
143+
inner: fmt.Errorf("obtaining fresh access token: %w", err),
144+
}
145+
// NOTE(review): this is a direct type assertion, so an adal.TokenRefreshError wrapped inside
// another error would not match here and would be left retryable
146+
rerr, ok := err.(adal.TokenRefreshError)
147+
if !ok {
148+
return bootstrapErr
149+
}
150+
// without an HTTP response there is no status code to classify, so keep the optimistic default
151+
resp := rerr.Response()
152+
if resp == nil {
153+
return bootstrapErr
154+
}
155+
// service principal (non-MSI) refreshes: only server-side (5xx and above) statuses are retried
156+
if !isMSI {
157+
bootstrapErr.retryable = resp.StatusCode >= http.StatusInternalServerError
158+
return bootstrapErr
159+
}
160+
// MSI refreshes: every status except 400 is classified by the imds package
161+
if resp.StatusCode != http.StatusBadRequest {
162+
bootstrapErr.retryable = imds.IsRetryableHTTPStatusCode(resp.StatusCode)
163+
return bootstrapErr
108164
}
109165

110-
return c.extractAccessTokenFunc(token)
166+
// 400s aren't normally retryable, though identity assignment can sometimes take a bit of time to propagate to IMDS,
167+
// so we treat "Identity not found" errors as retryable
168+
bootstrapErr.retryable = strings.Contains(strings.ToLower(err.Error()), "identity not found")
169+
return bootstrapErr
111170
}

0 commit comments

Comments
 (0)