Skip to content

Commit 3377b74

Browse files
refactor(client): allow specifying RPC deadlines and streamline retry logic (#162)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: cameronmeissner <24923771+cameronmeissner@users.noreply.github.com>
1 parent 1bd444e commit 3377b74

25 files changed

Lines changed: 724 additions & 707 deletions

.pipelines/client/e2e.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ stages:
8282
ADO_PAT: $(PAT-aksdevassistant)
8383
ORGANIZATION: $(ORGANIZATION)
8484
PROJECT: $(PROJECT)
85-
condition: succeeded()
85+
condition: and(succeeded(), ne(variables.SKIP_E2E, 'True'))
8686
displayName: ADO Login
8787
8888
- bash: /bin/bash .pipelines/client/scripts/run-pipeline.sh
@@ -92,7 +92,7 @@ stages:
9292
ADO_ORGANIZATION: $(ORGANIZATION)
9393
ADO_PROJECT: $(PROJECT)
9494
SUITE_ID: $(SECURE_TLS_BOOTSTRAPPING_CHECKIN_SUITE_ID)
95-
ADDITIONAL_VARS: "VSTS_TOGGLE_OVERRIDES=$(ENABLE_SECURE_TLS_BOOTSTRAPPING_TOGGLE_SETTINGS),$(TOGGLE_OVERRIDES)"
96-
condition: succeeded()
95+
ADDITIONAL_VARS: "$(TESTS_TO_RUN) VSTS_TOGGLE_OVERRIDES=$(ENABLE_SECURE_TLS_BOOTSTRAPPING_TOGGLE_SETTINGS),$(TOGGLE_OVERRIDES)"
96+
condition: and(succeeded(), ne(variables.SKIP_E2E, 'True'))
9797
displayName: Run Secure TLS Bootstrapping Check-in Tests
9898
timeoutInMinutes: 0

client/cmd/client/main.go

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515

1616
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/bootstrap"
1717
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/build"
18-
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/kubeconfig"
1918
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/log"
2019
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry"
2120
"go.uber.org/zap"
@@ -45,7 +44,13 @@ func init() {
4544
flag.StringVar(&config.TLSMinVersion, "tls-min-version", "", "the minimum TLS version used to communicate with control plane")
4645
flag.BoolVar(&config.InsecureSkipTLSVerify, "insecure-skip-tls-verify", false, "skip TLS verification when connecting to the control plane")
4746
flag.BoolVar(&config.EnsureAuthorizedClient, "ensure-authorized", false, "ensure the specified kubeconfig contains an authorized clientset before bootstrapping")
48-
flag.DurationVar(&config.Deadline, "deadline", 0, "the deadline within which bootstrapping must succeed")
47+
flag.DurationVar(&config.ValidateKubeconfigTimeout, "validate-kubeconfig-timeout", 0, "timeout applied to existing kubeconfig validation")
48+
flag.DurationVar(&config.GetAccessTokenTimeout, "get-access-token-timeout", 0, "timeout applied to the get access token RPC")
49+
flag.DurationVar(&config.GetInstanceDataTimeout, "get-instance-data-timeout", 0, "timeout applied to the get instance data RPC")
50+
flag.DurationVar(&config.GetNonceTimeout, "get-nonce-timeout", 0, "timeout applied to the get nonce RPC")
51+
flag.DurationVar(&config.GetAttestedDataTimeout, "get-attested-data-timeout", 0, "timeout applied to the get attested data RPC")
52+
flag.DurationVar(&config.GetCredentialTimeout, "get-credential-timeout", 0, "timeout applied to the get credential RPC")
53+
flag.DurationVar(&config.Deadline, "deadline", 0, "the deadline within which bootstrapping must succeed - DEPRECATED, use RPC timeouts instead")
4954

5055
flag.Usage = func() {
5156
_, _ = fmt.Fprintf(os.Stderr, "Usage of %s - %s:\n", os.Args[0], build.GetVersion())
@@ -83,61 +88,45 @@ func run(ctx context.Context) int {
8388

8489
ctx = log.WithLogger(telemetry.WithTracing(ctx), logger)
8590

86-
var endTime time.Time
91+
logger.Info("running with config", config.ToZapFields()...)
92+
93+
var bootstrapErr error
94+
var startTime, endTime time.Time
8795
result := &bootstrap.Result{
8896
Status: bootstrap.StatusSuccess,
8997
}
90-
91-
startTime := time.Now()
92-
deadline := startTime.Add(config.Deadline)
93-
bootstrapCtx, cancel := context.WithDeadline(ctx, deadline)
94-
defer cancel()
95-
logger.Info("set bootstrap deadline", zap.Time("deadline", deadline))
96-
97-
kubeconfigPath := config.KubeconfigPath
98-
err = kubeconfig.NewValidator().Validate(bootstrapCtx, kubeconfigPath, config.EnsureAuthorizedClient)
99-
if err == nil {
100-
logger.Info("existing kubeconfig is valid, will not bootstrap a new kubelet client credential", zap.String("kubeconfig", kubeconfigPath))
101-
endTime = time.Now()
98+
defer func() {
99+
if bootstrapErr != nil {
100+
result.Status = bootstrap.StatusFailure
101+
result.FinalErrorType = bootstrap.GetErrorType(bootstrapErr)
102+
result.FinalError = bootstrapErr.Error()
103+
}
104+
result.Trace = telemetry.GetTrace(ctx)
102105
emitGuestAgentEvent(logger, startTime, endTime, result)
103-
return 0
104-
}
105-
logger.Info("failed to validate existing kubeconfig, will bootstrap a new kubelet client credential", zap.String("kubeconfig", kubeconfigPath), zap.Error(err))
106+
}()
106107

107-
errLog, traces, err := bootstrap.Bootstrap(bootstrapCtx, config)
108+
startTime = time.Now()
109+
bootstrapErr = bootstrap.Bootstrap(ctx, config)
108110
endTime = time.Now()
109-
result.Errors = errLog
110-
result.Traces = traces.GetLastNTraces(5) // only keep the last 5 traces to avoid truncating guest agent event data
111-
result.TraceSummary = traces.GetTraceSummary()
112111

113112
var exitCode int
114-
if err != nil {
115-
result.Status = bootstrap.StatusFailure
113+
if bootstrapErr != nil {
116114
switch {
117-
case errors.Is(err, context.Canceled):
118-
logger.Error("context was cancelled before bootstrapping could complete")
119-
case errors.Is(err, context.DeadlineExceeded):
120-
err = errors.Unwrap(err)
121-
logger.Error(
122-
"failed to successfully bootstrap before the specified deadline",
123-
zap.Error(err),
124-
zap.Time("deadline", deadline),
125-
zap.Duration("deadlineDuration", config.Deadline),
126-
)
115+
case errors.Is(bootstrapErr, context.Canceled):
116+
logger.Error("context was canceled before bootstrapping could complete")
117+
case errors.Is(bootstrapErr, context.DeadlineExceeded):
118+
logger.Error("failed to bootstrap due to exceeding context deadline", zap.Error(bootstrapErr))
127119
default:
128-
logger.Error("failed to bootstrap", zap.Error(err))
120+
logger.Error("failed to bootstrap", zap.Error(bootstrapErr))
129121
}
130-
result.FinalError = err.Error()
131122
exitCode = 1
132123
}
133124

134-
emitGuestAgentEvent(logger, startTime, endTime, result)
135125
return exitCode
136126
}
137127

138128
func emitGuestAgentEvent(logger *zap.Logger, startTime, endTime time.Time, result *bootstrap.Result) {
139129
result.ElapsedMilliseconds = endTime.Sub(startTime).Milliseconds()
140-
141130
bootstrapEvent := &bootstrap.Event{
142131
Start: startTime,
143132
End: endTime,

client/go.mod

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ require (
66
github.com/Azure/aks-secure-tls-bootstrap/service v1.0.2
77
github.com/Azure/go-autorest/autorest v0.11.29
88
github.com/Azure/go-autorest/autorest/adal v0.9.22
9-
github.com/avast/retry-go/v4 v4.6.1
109
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2
1110
github.com/hashicorp/go-retryablehttp v0.7.7
1211
github.com/stretchr/testify v1.10.0

client/go.sum

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+Z
1919
github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8=
2020
github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo=
2121
github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU=
22-
github.com/avast/retry-go/v4 v4.6.1 h1:VkOLRubHdisGrHnTu89g08aQEWEgRU7LVEop3GbIcMk=
23-
github.com/avast/retry-go/v4 v4.6.1/go.mod h1:V6oF8njAwxJ5gRo1Q7Cxab24xs5NCWZBeaHHBklR8mA=
2422
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
2523
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
2624
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=

client/hack/linux/install.sh

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
set -euxo pipefail
33

44
# this script can be used to install an arbitrary version of aks-secure-tls-bootstrap-client
5-
# on a running AKS node for development/testing.
5+
# on a running AKS Linux node for development/testing.
66

77
# download and usage:
88
# 1. $ curl -o install-aks-secure-tls-bootstrap-client.sh https://raw.githubusercontent.com/Azure/aks-secure-tls-bootstrap/refs/heads/main/client/hack/linux/install.sh
@@ -14,12 +14,18 @@ VERSION="${VERSION:-}"
1414
[ -z "$VERSION" ] && echo "VERSION must be specified" && exit 1
1515
[ -z "$STORAGE_ACCOUNT_NAME" ] && echo "STORAGE_ACCOUNT_NAME must be specified" && exit 1
1616

17-
curl -fsSL https://${STORAGE_ACCOUNT_NAME}.z22.web.core.windows.net/client/linux/amd64/${VERSION} -o linux-amd64.tar.gz
17+
curl -fsSL https://${STORAGE_ACCOUNT_NAME}.z22.web.core.windows.net/client/linux/amd64/${VERSION}.tar.gz -o linux-amd64.tar.gz
1818
mkdir -p client
1919
tar -xvzf linux-amd64.tar.gz -C client
20-
rm /usr/local/bin/aks-secure-tls-bootstrap-client
2120
chmod +x client/aks-secure-tls-bootstrap-client
22-
mv client/aks-secure-tls-bootstrap-client /usr/local/bin/aks-secure-tls-bootstrap-client
21+
22+
rm -f /usr/local/bin/aks-secure-tls-bootstrap-client
23+
rm -f /opt/bin/aks-secure-tls-bootstrap-client
24+
mv client/aks-secure-tls-bootstrap-client /opt/bin/aks-secure-tls-bootstrap-client
25+
cp /opt/bin/aks-secure-tls-bootstrap-client /usr/local/bin/aks-secure-tls-bootstrap-client
26+
2327
rm -rf client
24-
rm linux-amd64.tar.gz
25-
stat /usr/local/bin/aks-secure-tls-bootstrap-client
28+
rm -f linux-amd64.tar.gz
29+
30+
stat /opt/bin/aks-secure-tls-bootstrap-client
31+
/opt/bin/aks-secure-tls-bootstrap-client -h

client/hack/windows/install.ps1

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ param(
1717

1818
Write-Host "Downloading aks-secure-tls-bootstrap-client version $Version from storage account $StorageAccountName"
1919

20-
$downloadUrl = "https://$StorageAccountName.z22.web.core.windows.net/client/windows/amd64/$Version"
20+
$downloadUrl = "https://$StorageAccountName.z22.web.core.windows.net/client/windows/amd64/$Version.zip"
2121
$archivePath = "windows-amd64.zip"
2222

2323
Write-Host "Downloading from: $downloadUrl"

client/internal/bootstrap/auth.go

Lines changed: 28 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -7,75 +7,73 @@ import (
77
"context"
88
"encoding/base64"
99
"fmt"
10-
"net/http"
1110
"strings"
1211

1312
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/cloud"
14-
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/imds"
1513
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/log"
16-
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry"
1714
"github.com/Azure/go-autorest/autorest/adal"
1815
"github.com/Azure/go-autorest/autorest/azure"
1916
"go.uber.org/zap"
2017
)
2118

2219
const (
20+
// service principal secrets containing this prefix are PFX certificates which need to be decoded,
21+
// rather than raw password / secret strings
2322
certificateSecretPrefix = "certificate:"
2423
)
2524

2625
const (
26+
// this will be the exact value of the "userAssignedIdentityID" field of the cloud provider config (azure.json)
27+
// when the node is using a (user-assigned) managed identity, rather than a service principal
2728
clientIDForMSI = "msi"
2829
)
2930

3031
const (
31-
maxMSIRefreshAttempts = 3
32+
// to avoid falling too deep into exponential backoff implemented by adal, which follows the public retry guidance:
33+
// https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#retry-guidance
34+
maxMSIRefreshAttempts = 5
3235
)
3336

3437
// extractAccessTokenFunc extracts an oauth access token from the specified service principal token after a refresh, fake implementations given in unit tests.
35-
type extractAccessTokenFunc func(token *adal.ServicePrincipalToken, isMSI bool) (string, error)
38+
type extractAccessTokenFunc func(ctx context.Context, token *adal.ServicePrincipalToken, isMSI bool) (string, error)
3639

37-
func extractAccessToken(token *adal.ServicePrincipalToken, isMSI bool) (string, error) {
38-
if err := token.Refresh(); err != nil {
39-
return "", tokenRefreshErrorToGetAccessTokenFailure(err, isMSI)
40+
func extractAccessToken(ctx context.Context, token *adal.ServicePrincipalToken, isMSI bool) (string, error) {
41+
if err := token.RefreshWithContext(ctx); err != nil {
42+
return "", err
4043
}
4144
return token.OAuthToken(), nil
4245
}
4346

44-
func (c *client) getAccessToken(ctx context.Context, userAssignedIdentityID, resource string, cloudProviderConfig *cloud.ProviderConfig) (string, error) {
45-
endSpan := telemetry.StartSpan(ctx, "GetAccessToken")
46-
defer endSpan()
47-
47+
func (c *client) getToken(ctx context.Context, config *Config) (string, error) {
4848
logger := log.MustGetLogger(ctx)
4949

50-
userAssignedID := cloudProviderConfig.UserAssignedIdentityID
51-
if userAssignedIdentityID != "" {
52-
userAssignedID = userAssignedIdentityID
50+
userAssignedID := config.CloudProviderConfig.UserAssignedIdentityID
51+
if config.UserAssignedIdentityID != "" {
52+
userAssignedID = config.UserAssignedIdentityID
5353
}
5454

5555
if userAssignedID != "" {
5656
logger.Info("generating MSI access token", zap.String("clientId", userAssignedID))
57-
msiToken, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, &adal.ManagedIdentityOptions{
57+
msiToken, err := adal.NewServicePrincipalTokenFromManagedIdentity(config.AADResource, &adal.ManagedIdentityOptions{
5858
ClientID: userAssignedID,
5959
})
6060
if err != nil {
61-
return "", makeNonRetryableGetAccessTokenFailure(fmt.Errorf("generating MSI access token: %w", err))
61+
return "", fmt.Errorf("generating MSI access token: %w", err)
6262
}
63-
// to avoid falling too deep into exponential backoff implemented by adal, which follows the public retry guidance:
64-
// https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#retry-guidance
6563
msiToken.MaxMSIRefreshAttempts = maxMSIRefreshAttempts
66-
return c.extractAccessTokenFunc(msiToken, true)
64+
return c.extractAccessTokenFunc(ctx, msiToken, true)
6765
}
6866

69-
if cloudProviderConfig.ClientID == clientIDForMSI {
70-
return "", makeNonRetryableGetAccessTokenFailure(fmt.Errorf("client ID within cloud provider config indicates usage of a managed identity, though no user-assigned identity ID was provided"))
67+
if config.CloudProviderConfig.ClientID == clientIDForMSI {
68+
return "", fmt.Errorf("client ID within cloud provider config indicates usage of a managed identity, though no user-assigned identity ID was provided")
7169
}
7270

73-
servicePrincipalToken, err := getServicePrincipalToken(ctx, resource, cloudProviderConfig)
71+
servicePrincipalToken, err := getServicePrincipalToken(ctx, config.AADResource, config.CloudProviderConfig)
7472
if err != nil {
7573
return "", err
7674
}
7775

78-
return c.extractAccessTokenFunc(servicePrincipalToken, false)
76+
return c.extractAccessTokenFunc(ctx, servicePrincipalToken, false)
7977
}
8078

8179
func getServicePrincipalToken(ctx context.Context, resource string, cloudProviderConfig *cloud.ProviderConfig) (*adal.ServicePrincipalToken, error) {
@@ -85,18 +83,18 @@ func getServicePrincipalToken(ctx context.Context, resource string, cloudProvide
8583

8684
env, err := azure.EnvironmentFromName(cloudProviderConfig.CloudName)
8785
if err != nil {
88-
return nil, makeNonRetryableGetAccessTokenFailure(fmt.Errorf("getting azure environment config for cloud %q: %w", cloudProviderConfig.CloudName, err))
86+
return nil, fmt.Errorf("getting azure environment config for cloud %q: %w", cloudProviderConfig.CloudName, err)
8987
}
9088
oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, cloudProviderConfig.TenantID)
9189
if err != nil {
92-
return nil, makeNonRetryableGetAccessTokenFailure(fmt.Errorf("creating oauth config with azure environment: %w", err))
90+
return nil, fmt.Errorf("creating oauth config with azure environment: %w", err)
9391
}
9492

9593
if !strings.HasPrefix(secret, certificateSecretPrefix) {
9694
logger.Info("generating service principal access token with client secret", zap.String("clientId", cloudProviderConfig.ClientID))
9795
token, err := adal.NewServicePrincipalToken(*oauthConfig, cloudProviderConfig.ClientID, secret, resource)
9896
if err != nil {
99-
return nil, makeNonRetryableGetAccessTokenFailure(fmt.Errorf("generating service principal access token with client secret: %w", err))
97+
return nil, fmt.Errorf("generating service principal access token with client secret: %w", err)
10098
}
10199
return token, nil
102100
}
@@ -105,17 +103,17 @@ func getServicePrincipalToken(ctx context.Context, resource string, cloudProvide
105103

106104
certData, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(secret, certificateSecretPrefix))
107105
if err != nil {
108-
return nil, makeNonRetryableGetAccessTokenFailure(fmt.Errorf("b64-decoding certificate data in client secret: %w", err))
106+
return nil, fmt.Errorf("b64-decoding certificate data in client secret: %w", err)
109107
}
110108
certificate, privateKey, err := adal.DecodePfxCertificateData(certData, "")
111109
if err != nil {
112-
return nil, makeNonRetryableGetAccessTokenFailure(fmt.Errorf("decoding pfx certificate data in client secret: %w", err))
110+
return nil, fmt.Errorf("decoding pfx certificate data in client secret: %w", err)
113111
}
114112

115113
logger.Info("generating service principal access token with certificate", zap.String("clientId", cloudProviderConfig.ClientID))
116114
token, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, cloudProviderConfig.ClientID, certificate, privateKey, resource)
117115
if err != nil {
118-
return nil, makeNonRetryableGetAccessTokenFailure(fmt.Errorf("generating service principal access token with certificate: %w", err))
116+
return nil, fmt.Errorf("generating service principal access token with certificate: %w", err)
119117
}
120118

121119
return token, nil
@@ -127,44 +125,3 @@ func maybeB64Decode(str string) string {
127125
}
128126
return str
129127
}
130-
131-
func makeNonRetryableGetAccessTokenFailure(err error) error {
132-
return &bootstrapError{
133-
errorType: ErrorTypeGetAccessTokenFailure,
134-
retryable: false,
135-
inner: err,
136-
}
137-
}
138-
139-
func tokenRefreshErrorToGetAccessTokenFailure(err error, isMSI bool) error {
140-
bootstrapErr := &bootstrapError{
141-
errorType: ErrorTypeGetAccessTokenFailure,
142-
retryable: true, // optimistically consider the error retryable from the start
143-
inner: fmt.Errorf("obtaining fresh access token: %w", err),
144-
}
145-
146-
rerr, ok := err.(adal.TokenRefreshError)
147-
if !ok {
148-
return bootstrapErr
149-
}
150-
151-
resp := rerr.Response()
152-
if resp == nil {
153-
return bootstrapErr
154-
}
155-
156-
if !isMSI {
157-
bootstrapErr.retryable = resp.StatusCode >= http.StatusInternalServerError
158-
return bootstrapErr
159-
}
160-
161-
if resp.StatusCode != http.StatusBadRequest {
162-
bootstrapErr.retryable = imds.IsRetryableHTTPStatusCode(resp.StatusCode)
163-
return bootstrapErr
164-
}
165-
166-
// 400s aren't normally retryable, though identity assignment can sometimes take a bit of time to propagate to IMDS,
167-
// so we treat "Identity not found" errors as retryable
168-
bootstrapErr.retryable = strings.Contains(strings.ToLower(err.Error()), "identity not found")
169-
return bootstrapErr
170-
}

0 commit comments

Comments
 (0)