-
Notifications
You must be signed in to change notification settings - Fork 403
Expand file tree
/
Copy pathManagedIdentityAuthRequest.cs
More file actions
285 lines (242 loc) · 14.5 KB
/
ManagedIdentityAuthRequest.cs
File metadata and controls
285 lines (242 loc) · 14.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System.Collections.Generic;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client.ApiConfig.Parameters;
using Microsoft.Identity.Client.AuthScheme.PoP;
using Microsoft.Identity.Client.Cache.Items;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.ManagedIdentity;
using Microsoft.Identity.Client.OAuth2;
using Microsoft.Identity.Client.PlatformsCommon.Interfaces;
using Microsoft.Identity.Client.Utils;
namespace Microsoft.Identity.Client.Internal.Requests
{
    /// <summary>
    /// Token request for managed identity. Serves access tokens from the token cache when allowed and
    /// otherwise acquires them from the managed identity endpoint through <see cref="ManagedIdentityClient"/>.
    /// </summary>
    internal class ManagedIdentityAuthRequest : RequestBase
    {
        private readonly AcquireTokenForManagedIdentityParameters _managedIdentityParameters;
        private readonly ManagedIdentityClient _managedIdentityClient;

        // Static, so it is shared by every instance: at most one caller at a time is inside the
        // endpoint/cache-recheck section of GetAccessTokenAsync.
        private static readonly SemaphoreSlim s_semaphoreSlim = new SemaphoreSlim(1, 1);

        private readonly ICryptographyManager _cryptoManager;

        // NOTE(review): assigned in the constructor but never read in this class.
        private readonly IManagedIdentityKeyProvider _managedIdentityKeyProvider;

        /// <summary>
        /// Creates a managed identity token request.
        /// </summary>
        /// <param name="serviceBundle">Service bundle; its platform proxy supplies the cryptography manager and key provider.</param>
        /// <param name="authenticationRequestParameters">Common request parameters (scopes, claims, logger, telemetry event).</param>
        /// <param name="managedIdentityParameters">Managed-identity-specific parameters sent to the endpoint.</param>
        /// <param name="managedIdentityClient">Client used to call the managed identity endpoint.</param>
        public ManagedIdentityAuthRequest(
            IServiceBundle serviceBundle,
            AuthenticationRequestParameters authenticationRequestParameters,
            AcquireTokenForManagedIdentityParameters managedIdentityParameters,
            ManagedIdentityClient managedIdentityClient)
            : base(serviceBundle, authenticationRequestParameters, managedIdentityParameters)
        {
            _managedIdentityParameters = managedIdentityParameters;
            _managedIdentityClient = managedIdentityClient;
            _cryptoManager = serviceBundle.PlatformProxy.CryptographyManager;
            _managedIdentityKeyProvider = serviceBundle.PlatformProxy.ManagedIdentityKeyProvider;
        }

        /// <summary>
        /// Resolves a managed identity token. Precedence:
        /// 1) ForceRefresh: skip the cache;
        /// 2) Claims: skip the cache, passing the hash of any cached (revoked) token to the endpoint;
        /// 3) cache hit: return it, starting a proactive background refresh when one is due;
        /// 4) cache miss: request a token from the endpoint.
        /// </summary>
        /// <param name="cancellationToken">Cancels the request.</param>
        /// <returns>The authentication result, from the cache or from the managed identity endpoint.</returns>
        protected override async Task<AuthenticationResult> ExecuteAsync(CancellationToken cancellationToken)
        {
            ILoggerAdapter logger = AuthenticationRequestParameters.RequestContext.Logger;

            // Prime the scheme before any cache lookup if we already have a binding cert from a prior mint.
            ApplyPriorMtlsBindingCertificate(logger);

            // 1. ForceRefresh goes straight to the managed identity endpoint.
            if (_managedIdentityParameters.ForceRefresh)
            {
                if (!string.IsNullOrEmpty(AuthenticationRequestParameters.Claims))
                {
                    logger.Warning("[ManagedIdentityRequest] Both ForceRefresh and Claims are set. Using ForceRefresh to skip cache.");
                }

                AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ForceRefreshOrClaims;
                logger.Info("[ManagedIdentityRequest] Skipped using the cache because ForceRefresh was set.");
                return await GetAccessTokenAsync(cancellationToken, logger).ConfigureAwait(false);
            }

            // 2. Look up the cache. With Claims the cached token is not returned, but it is still
            //    needed to compute the revoked-token hash.
            MsalAccessTokenCacheItem cachedAccessTokenItem = await GetCachedAccessTokenAsync().ConfigureAwait(false);

            if (!string.IsNullOrEmpty(AuthenticationRequestParameters.Claims))
            {
                _managedIdentityParameters.Claims = AuthenticationRequestParameters.Claims;
                AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ForceRefreshOrClaims;

                if (cachedAccessTokenItem != null)
                {
                    // "Revoked token" scenario: tell the endpoint which token is being replaced.
                    _managedIdentityParameters.RevokedTokenHash = _cryptoManager.CreateSha256HashHex(cachedAccessTokenItem.Secret);
                    logger.Info("[ManagedIdentityRequest] Claims are present. Computed hash of the cached (revoked) token. " +
                        "Will now request a fresh token from the MI endpoint.");
                }
                else
                {
                    logger.Info("[ManagedIdentityRequest] Claims are present, but no cached token was found. " +
                        "Requesting a fresh token from the MI endpoint without a revoked-token hash.");
                }

                return await GetAccessTokenAsync(cancellationToken, logger).ConfigureAwait(false);
            }

            // 3. No ForceRefresh and no Claims: a cache miss means going to the endpoint.
            if (cachedAccessTokenItem == null)
            {
                // Keep a more specific "Expired" reason if one was already recorded.
                if (AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo != CacheRefreshReason.Expired)
                {
                    AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.NoCachedAccessToken;
                }

                logger.Info("[ManagedIdentityRequest] No cached access token found. " +
                    "Getting a token from the managed identity endpoint.");
                return await GetAccessTokenAsync(cancellationToken, logger).ConfigureAwait(false);
            }

            // 4. Cache hit: return the cached token, refreshing it in the background when due.
            AuthenticationResult authResult = await CreateAuthenticationResultFromCacheAsync(cachedAccessTokenItem, cancellationToken).ConfigureAwait(false);
            logger.Info("[ManagedIdentityRequest] Access token retrieved from cache.");

            try
            {
                if (SilentRequestHelper.NeedsRefresh(cachedAccessTokenItem))
                {
                    logger.Info("[ManagedIdentityRequest] Initiating a proactive refresh.");
                    AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ProactivelyRefreshed;

                    SilentRequestHelper.ProcessFetchInBackground(
                        cachedAccessTokenItem,
                        async () =>
                        {
                            // Use a linked token source, in case the original cancellation token source is disposed
                            // before this background task completes. The fetch is awaited INSIDE the using scope so the
                            // linked source stays alive until the fetch finishes; returning the un-awaited task would
                            // dispose it immediately.
                            using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
                            return await GetAccessTokenAsync(tokenSource.Token, logger).ConfigureAwait(false);
                        },
                        logger,
                        ServiceBundle,
                        AuthenticationRequestParameters.RequestContext.ApiEvent,
                        AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkApiId,
                        AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkVersion);
                }
            }
            catch (MsalServiceException e)
            {
                // NOTE(review): ProcessFetchInBackground presumably runs the fetch asynchronously, in which case
                // failures of the refresh itself would not surface here - confirm.
                return await HandleTokenRefreshErrorAsync(e, cachedAccessTokenItem, cancellationToken).ConfigureAwait(false);
            }

            return authResult;
        }

        /// <summary>
        /// Installs the mTLS PoP scheme for a binding certificate kept from an earlier acquisition.
        /// Does this only when mTLS PoP is requested and the client already holds such a certificate,
        /// so the following cache lookup uses that scheme.
        /// </summary>
        /// <param name="logger">Request logger.</param>
        private void ApplyPriorMtlsBindingCertificate(ILoggerAdapter logger)
        {
            if (!AuthenticationRequestParameters.IsMtlsPopRequested)
            {
                return;
            }

            // Read the shared client property once so the scheme and the logged thumbprint refer to the same certificate.
            var priorCertificate = _managedIdentityClient.RuntimeMtlsBindingCertificate;
            if (priorCertificate == null)
            {
                return;
            }

            AuthenticationRequestParameters.AuthenticationScheme = new MtlsPopAuthenticationOperation(priorCertificate);
            logger.Info("[ManagedIdentity] Using prior mTLS binding certificate for cache lookup.");
            logger.InfoPii(
                () => $"[ManagedIdentity][PII] Prior mTLS cert thumbprint: {priorCertificate.Thumbprint}",
                () => "[ManagedIdentity][PII] Prior mTLS cert thumbprint: ***");
        }

        /// <summary>
        /// Acquires a token while holding the static semaphore.
        /// Unless ForceRefresh, a proactive refresh, or Claims require bypassing the cache, the cache is
        /// checked again first, because another caller may have stored a token while this one waited.
        /// </summary>
        /// <param name="cancellationToken">Cancels the semaphore wait and the request.</param>
        /// <param name="logger">Request logger.</param>
        /// <returns>The authentication result, from the cache or from the managed identity endpoint.</returns>
        private async Task<AuthenticationResult> GetAccessTokenAsync(
            CancellationToken cancellationToken,
            ILoggerAdapter logger)
        {
            // Requests to a managed identity endpoint must be throttled;
            // otherwise, the endpoint may respond with HTTP 429.
            logger.Verbose(() => "[ManagedIdentityRequest] Entering managed identity request semaphore.");
            await s_semaphoreSlim.WaitAsync(cancellationToken).ConfigureAwait(false);
            logger.Verbose(() => "[ManagedIdentityRequest] Entered managed identity request semaphore.");

            try
            {
                bool bypassCache =
                    _managedIdentityParameters.ForceRefresh ||
                    AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo == CacheRefreshReason.ProactivelyRefreshed ||
                    !string.IsNullOrEmpty(_managedIdentityParameters.Claims);

                if (bypassCache)
                {
                    return await SendTokenRequestForManagedIdentityAsync(logger, cancellationToken).ConfigureAwait(false);
                }

                logger.Info("[ManagedIdentityRequest] Checking for a cached access token.");
                MsalAccessTokenCacheItem cachedAccessTokenItem = await GetCachedAccessTokenAsync().ConfigureAwait(false);

                if (cachedAccessTokenItem != null)
                {
                    return await CreateAuthenticationResultFromCacheAsync(cachedAccessTokenItem, cancellationToken).ConfigureAwait(false);
                }

                return await SendTokenRequestForManagedIdentityAsync(logger, cancellationToken).ConfigureAwait(false);
            }
            finally
            {
                s_semaphoreSlim.Release();
                logger.Verbose(() => "[ManagedIdentityRequest] Released managed identity request semaphore.");
            }
        }

        /// <summary>
        /// Calls the managed identity endpoint, then caches the response and builds the result.
        /// When mTLS PoP is requested and the endpoint call produced a binding certificate, that certificate
        /// is remembered on the client and the mTLS PoP scheme is applied before caching.
        /// </summary>
        /// <param name="logger">Request logger.</param>
        /// <param name="cancellationToken">Cancels the request.</param>
        /// <returns>The authentication result built from the endpoint response.</returns>
        private async Task<AuthenticationResult> SendTokenRequestForManagedIdentityAsync(ILoggerAdapter logger, CancellationToken cancellationToken)
        {
            logger.Info("[ManagedIdentityRequest] Acquiring a token from the managed identity endpoint.");

            await ResolveAuthorityAsync().ConfigureAwait(false);

            _managedIdentityParameters.IsMtlsPopRequested = AuthenticationRequestParameters.IsMtlsPopRequested;

            // Propagate client-originated claims to the MI parameters for transport.
            // Unlike server-issued Claims (which bypass the cache), ClientClaims are cached normally.
            if (!string.IsNullOrEmpty(AuthenticationRequestParameters.ClientClaims))
            {
                _managedIdentityParameters.ClientClaims = AuthenticationRequestParameters.ClientClaims;
            }

            ManagedIdentityResponse managedIdentityResponse = await _managedIdentityClient
                .SendTokenRequestForManagedIdentityAsync(AuthenticationRequestParameters.RequestContext, _managedIdentityParameters, cancellationToken)
                .ConfigureAwait(false);

            var mtlsCertificate = _managedIdentityParameters.MtlsCertificate;
            if (AuthenticationRequestParameters.IsMtlsPopRequested && mtlsCertificate != null)
            {
                // Remember the cert so later requests can prime the scheme before their cache lookup.
                _managedIdentityClient.SetRuntimeMtlsBindingCertificate(mtlsCertificate);

                // Apply the mTLS scheme BEFORE caching so the token is stored under it.
                AuthenticationRequestParameters.AuthenticationScheme = new MtlsPopAuthenticationOperation(mtlsCertificate);
                _managedIdentityParameters.MtlsCertificate = null;

                logger.Info("[ManagedIdentity] Applied mtls_pop scheme prior to caching.");
            }

            var msalTokenResponse = MsalTokenResponse.CreateFromManagedIdentityResponse(managedIdentityResponse);
            msalTokenResponse.Scope = AuthenticationRequestParameters.Scope.AsSingleString();

            return await CacheTokenResponseAndCreateAuthenticationResultAsync(msalTokenResponse, cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Looks up an access token in the cache. On a hit, records it in telemetry
        /// (IsAccessTokenCacheHit and the cache-token metric).
        /// </summary>
        /// <returns>The cached item, or null when none is found.</returns>
        private async Task<MsalAccessTokenCacheItem> GetCachedAccessTokenAsync()
        {
            MsalAccessTokenCacheItem cachedAccessTokenItem = await CacheManager.FindAccessTokenAsync().ConfigureAwait(false);

            if (cachedAccessTokenItem != null)
            {
                AuthenticationRequestParameters.RequestContext.ApiEvent.IsAccessTokenCacheHit = true;
                Metrics.IncrementTotalAccessTokensFromCache();
            }

            return cachedAccessTokenItem;
        }

        /// <summary>
        /// Builds an <see cref="AuthenticationResult"/> with <c>TokenSource.Cache</c> from a cached access
        /// token, using the current authentication scheme. No ID token or account is attached.
        /// </summary>
        private Task<AuthenticationResult> CreateAuthenticationResultFromCacheAsync(
            MsalAccessTokenCacheItem cachedAccessTokenItem, CancellationToken cancellationToken)
        {
            return AuthenticationResult.CreateAsync(
                msalAccessTokenCacheItem: cachedAccessTokenItem,
                msalIdTokenCacheItem: null,
                authenticationScheme: AuthenticationRequestParameters.AuthenticationScheme,
                correlationId: AuthenticationRequestParameters.RequestContext.CorrelationId,
                tokenSource: TokenSource.Cache,
                apiEvent: AuthenticationRequestParameters.RequestContext.ApiEvent,
                account: null,
                spaAuthCode: null,
                additionalResponseParameters: null,
                cancellationToken: cancellationToken);
        }

        /// <summary>
        /// Managed identity requests send no CCS routing header.
        /// </summary>
        protected override KeyValuePair<string, string>? GetCcsHeader(IDictionary<string, string> additionalBodyParameters) => null;
    }
}