-
Notifications
You must be signed in to change notification settings - Fork 403
Expand file tree
/
Copy pathAbstractManagedIdentity.cs
More file actions
350 lines (296 loc) · 14.4 KB
/
AbstractManagedIdentity.cs
File metadata and controls
350 lines (296 loc) · 14.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client.Http;
using Microsoft.Identity.Client.Utils;
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.Core;
using System.Net;
using Microsoft.Identity.Client.ApiConfig.Parameters;
using System.Text;
using System.Security.Cryptography.X509Certificates;
using System.Net.Security;
using Microsoft.Identity.Client.Http.Retry;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
namespace Microsoft.Identity.Client.ManagedIdentity
{
    /// <summary>
    /// Base class for managed identity sources. A derived source builds its endpoint-specific
    /// request in <see cref="CreateRequestAsync(string)"/>; this class then adds client claims,
    /// claims/capabilities and extra query parameters, sends the request through the service
    /// bundle's HTTP manager with the configured retry policy, and turns the HTTP response (or
    /// failure) into a <see cref="ManagedIdentityResponse"/> or an MSAL exception.
    /// </summary>
    internal abstract class AbstractManagedIdentity
    {
        private const string ManagedIdentityPrefix = "[Managed Identity] ";

        protected readonly RequestContext _requestContext;

        // Protected so derived sources can inspect it; it is assigned from the current
        // parameters before CreateRequestAsync is invoked.
        protected bool _isMtlsPopRequested;

        internal const string TimeoutError = "[Managed Identity] Authentication unavailable. The request to the managed identity endpoint timed out.";

        internal readonly ManagedIdentitySource _sourceType;

        protected AbstractManagedIdentity(RequestContext requestContext, ManagedIdentitySource sourceType)
        {
            _requestContext = requestContext;
            _sourceType = sourceType;
        }

        /// <summary>
        /// Acquires a token from the managed identity endpoint.
        /// </summary>
        /// <param name="parameters">
        /// Resource, client claims and mTLS PoP settings for this request. When mTLS PoP is requested
        /// and the source attached a certificate to the request, <c>parameters.MtlsCertificate</c>
        /// is set to that certificate.
        /// </param>
        /// <param name="cancellationToken">Checked before the request is built and passed to the HTTP manager.</param>
        /// <returns>The token response produced by <see cref="HandleResponseAsync"/>.</returns>
        /// <exception cref="OperationCanceledException">
        /// Cancellation was requested, or the HTTP call was cancelled or timed out; rethrown unchanged.
        /// </exception>
        /// <exception cref="MsalServiceException">
        /// The endpoint was unreachable, the endpoint was malformed, the endpoint returned an error,
        /// or the response could not be used.
        /// </exception>
        public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
            AcquireTokenForManagedIdentityParameters parameters,
            CancellationToken cancellationToken)
        {
            if (cancellationToken.IsCancellationRequested)
            {
                _requestContext.Logger.Error(TimeoutError);
                cancellationToken.ThrowIfCancellationRequested();
            }

            // The resource is handed to the derived source unchanged.
            string resource = parameters.Resource;

            // Must be set before CreateRequestAsync so the derived source can observe it.
            _isMtlsPopRequested = parameters.IsMtlsPopRequested;

            ManagedIdentityRequest request = await CreateRequestAsync(resource).ConfigureAwait(false);

            // Forward client-originated claims to the correct location:
            // - GET requests (IMDS/MSIv1): append as "claims" query parameter
            // - POST requests (ImdsV2 / ESTS): append as "claims" body parameter
            //   (FormUrlEncodedContent encodes body values, so no explicit escaping there)
            if (!string.IsNullOrEmpty(parameters.ClientClaims))
            {
                if (request.Method == HttpMethod.Get)
                {
                    request.QueryParameters["claims"] = Uri.EscapeDataString(parameters.ClientClaims);
                    _requestContext.Logger.Info("[Managed Identity] Adding client claims to IMDS request as query parameter.");
                }
                else
                {
                    request.BodyParameters["claims"] = parameters.ClientClaims;
                    _requestContext.Logger.Info("[Managed Identity] Adding client claims to ESTS POST body.");
                }
            }

            // When IMDSv2 mints a binding certificate during this request (via CSR),
            // it's exposed via request.MtlsCertificate. Bubble it up so the request
            // layer can set the mtls_pop scheme.
            if (parameters.IsMtlsPopRequested && request.MtlsCertificate != null)
            {
                parameters.MtlsCertificate = request.MtlsCertificate;
            }

            // Automatically add claims / capabilities if this MI source supports them.
            if (_sourceType.SupportsClaimsAndCapabilities())
            {
                request.AddClaimsAndCapabilities(
                    _requestContext.ServiceBundle.Config.ClientCapabilities,
                    parameters,
                    _requestContext.Logger);
            }

            request.AddExtraQueryParams(
                _requestContext.ServiceBundle.Config.ExtraQueryParameters,
                _requestContext.Logger);

            _requestContext.Logger.Info("[Managed Identity] Sending request to managed identity endpoints.");

            IRetryPolicy retryPolicy = _requestContext.ServiceBundle.Config.RetryPolicyFactory.GetRetryPolicy(request.RequestType);

            try
            {
                // GET requests carry no body; any other method is sent as a form-encoded POST.
                bool isGetRequest = request.Method == HttpMethod.Get;
                var endpoint = request.ComputeUri();

                // The form content is owned by this method; dispose it once the exchange
                // (including any retries performed by the HTTP manager) has completed.
                using HttpContent body = isGetRequest
                    ? null
                    : new FormUrlEncodedContent(request.BodyParameters);

                HttpResponse response = await _requestContext.ServiceBundle.HttpManager
                    .SendRequestAsync(
                        endpoint,
                        request.Headers,
                        body: body,
                        method: isGetRequest ? HttpMethod.Get : HttpMethod.Post,
                        logger: _requestContext.Logger,
                        doNotThrow: true,
                        mtlsCertificate: request.MtlsCertificate,
                        validateServerCertificate: GetValidationCallback(),
                        cancellationToken: cancellationToken,
                        retryPolicy: retryPolicy)
                    .ConfigureAwait(false);

                return await HandleResponseAsync(parameters, response, cancellationToken).ConfigureAwait(false);
            }
            catch (Exception ex)
            {
                // HandleException throws a wrapped MSAL exception for network, format and other
                // unexpected failures. For cancellation and existing MsalServiceExceptions it returns,
                // and the original exception is rethrown with its stack trace intact.
                HandleException(ex);
                throw;
            }
        }

        /// <summary>
        /// Method to be overridden in the derived classes to provide a custom validation callback for the server certificate.
        /// This validation is needed for service fabric managed identity endpoints.
        /// </summary>
        /// <returns>Callback to validate the server certificate, or <c>null</c> to use the default validation.</returns>
        internal virtual Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> GetValidationCallback()
        {
            return null;
        }

        /// <summary>
        /// Converts the HTTP response from the managed identity endpoint into a token response.
        /// </summary>
        /// <param name="parameters">The parameters of the current request (unused by the base implementation).</param>
        /// <param name="response">The HTTP response received from the endpoint.</param>
        /// <param name="cancellationToken">Unused by the base implementation.</param>
        /// <returns>The parsed response when the status code is 200 OK.</returns>
        /// <exception cref="MsalServiceException">
        /// The status code is not 200 OK (<see cref="MsalError.ManagedIdentityRequestFailed"/>), or the body is unusable.
        /// </exception>
        protected virtual Task<ManagedIdentityResponse> HandleResponseAsync(
            AcquireTokenForManagedIdentityParameters parameters,
            HttpResponse response,
            CancellationToken cancellationToken)
        {
            if (response.StatusCode == HttpStatusCode.OK)
            {
                _requestContext.Logger.Info("[Managed Identity] Successful response received.");
                return Task.FromResult(GetSuccessfulResponse(response));
            }

            string message = GetMessageFromErrorResponse(response);
            _requestContext.Logger.Error($"[Managed Identity] request failed, HttpStatusCode: {response.StatusCode} Error message: {message}");

            MsalException exception = MsalServiceExceptionFactory.CreateManagedIdentityException(
                MsalError.ManagedIdentityRequestFailed,
                message,
                null,
                _sourceType,
                (int)response.StatusCode);

            throw exception;
        }

        /// <summary>
        /// Builds the source-specific request for <paramref name="resource"/>.
        /// </summary>
        protected abstract Task<ManagedIdentityRequest> CreateRequestAsync(string resource);

        /// <summary>
        /// Deserializes a 200 OK body into a <see cref="ManagedIdentityResponse"/>.
        /// </summary>
        /// <exception cref="MsalServiceException">
        /// The body is not valid JSON (<see cref="MsalError.ManagedIdentityResponseParseFailure"/>), or it
        /// deserializes to null or lacks an access token / expiry (<see cref="MsalError.ManagedIdentityRequestFailed"/>).
        /// </exception>
        protected ManagedIdentityResponse GetSuccessfulResponse(HttpResponse response)
        {
            ManagedIdentityResponse managedIdentityResponse;

            try
            {
                managedIdentityResponse = JsonHelper.DeserializeFromJson<ManagedIdentityResponse>(response.Body);
            }
            catch (JsonException ex)
            {
                _requestContext.Logger.Error($"[Managed Identity] MSI json response failed to parse. {ex}");

                var exception = MsalServiceExceptionFactory.CreateManagedIdentityException(
                    MsalError.ManagedIdentityResponseParseFailure,
                    MsalErrorMessage.ManagedIdentityJsonParseFailure,
                    ex,
                    _sourceType,
                    (int)HttpStatusCode.OK);

                throw exception;
            }

            if (managedIdentityResponse == null ||
                managedIdentityResponse.AccessToken.IsNullOrEmpty() ||
                managedIdentityResponse.ExpiresOn.IsNullOrEmpty())
            {
                _requestContext.Logger.Error("[Managed Identity] Response is either null or insufficient for authentication.");

                var exception = MsalServiceExceptionFactory.CreateManagedIdentityException(
                    MsalError.ManagedIdentityRequestFailed,
                    MsalErrorMessage.ManagedIdentityInvalidResponse,
                    null,
                    _sourceType,
                    (int)HttpStatusCode.OK);

                throw exception;
            }

            return managedIdentityResponse;
        }

        /// <summary>
        /// Produces a human-readable error message from an error response body.
        /// </summary>
        /// <remarks>
        /// It first tries the flat <see cref="ManagedIdentityErrorResponse"/> shape. If that cannot be
        /// deserialized, or deserializes to null, it falls back to the nested
        /// <c>{ "error": { "code", "message" } }</c> shape.
        /// </remarks>
        /// <returns>
        /// <see cref="MsalErrorMessage.ManagedIdentityNoResponseReceived"/> when there is no body;
        /// otherwise a message built from the body.
        /// </returns>
        internal string GetMessageFromErrorResponse(HttpResponse response)
        {
            string body = response?.Body;

            if (string.IsNullOrEmpty(body))
            {
                return MsalErrorMessage.ManagedIdentityNoResponseReceived;
            }

            ManagedIdentityErrorResponse managedIdentityErrorResponse;

            try
            {
                managedIdentityErrorResponse = JsonHelper.DeserializeFromJson<ManagedIdentityErrorResponse>(body);
            }
            catch
            {
                // Not in the flat shape; try the nested (e.g. cloud shell) shape instead.
                return TryGetMessageFromNestedErrorResponse(body);
            }

            // A body such as "null" deserializes to null. Fall back explicitly rather than
            // relying on a NullReferenceException being swallowed.
            return managedIdentityErrorResponse == null
                ? TryGetMessageFromNestedErrorResponse(body)
                : ExtractErrorMessageFromManagedIdentityErrorResponse(managedIdentityErrorResponse);
        }

        /// <summary>
        /// Concatenates whichever of error code, message, description and correlation id are present.
        /// </summary>
        /// <returns>
        /// The composed message, or <see cref="MsalErrorMessage.ManagedIdentityUnexpectedErrorResponse"/>
        /// followed by a period when none of the fields are present.
        /// </returns>
        private string ExtractErrorMessageFromManagedIdentityErrorResponse(ManagedIdentityErrorResponse managedIdentityErrorResponse)
        {
            StringBuilder stringBuilder = new StringBuilder(ManagedIdentityPrefix);

            if (!string.IsNullOrEmpty(managedIdentityErrorResponse.Error))
            {
                stringBuilder.Append($"Error Code: {managedIdentityErrorResponse.Error} ");
            }

            if (!string.IsNullOrEmpty(managedIdentityErrorResponse.Message))
            {
                stringBuilder.Append($"Error Message: {managedIdentityErrorResponse.Message} ");
            }

            if (!string.IsNullOrEmpty(managedIdentityErrorResponse.ErrorDescription))
            {
                stringBuilder.Append($"Error Description: {managedIdentityErrorResponse.ErrorDescription} ");
            }

            if (!string.IsNullOrEmpty(managedIdentityErrorResponse.CorrelationId))
            {
                stringBuilder.Append($"Managed Identity Correlation ID: {managedIdentityErrorResponse.CorrelationId} Use this Correlation ID for further investigation.");
            }

            // Nothing beyond the prefix was appended.
            if (stringBuilder.Length == ManagedIdentityPrefix.Length)
            {
                return $"{MsalErrorMessage.ManagedIdentityUnexpectedErrorResponse}.";
            }

            return stringBuilder.ToString();
        }

        /// <summary>
        /// Tries to get the error message from the nested error response used by cloud shell,
        /// i.e. <c>{ "error": { "code": ..., "message": ... } }</c>.
        /// </summary>
        /// <returns>
        /// A message built from the nested code/message when either is present; otherwise a generic
        /// message that embeds the raw response (which is also logged at error level).
        /// </returns>
        private string TryGetMessageFromNestedErrorResponse(string response)
        {
            try
            {
                var json = JsonHelper.ParseIntoJsonObject(response);

                if (JsonHelper.TryGetValue(json, "error", out var error))
                {
                    var errorObject = JsonHelper.ToJsonObject(error);
                    StringBuilder errorMessage = new StringBuilder(ManagedIdentityPrefix);

                    if (JsonHelper.TryGetValue(errorObject, "code", out var errorCode))
                    {
                        errorMessage.Append($"Error Code: {errorCode} ");
                    }

                    if (JsonHelper.TryGetValue(errorObject, "message", out var message))
                    {
                        errorMessage.Append($"Error Message: {message}");
                    }

                    if (message != null || errorCode != null)
                    {
                        return errorMessage.ToString();
                    }
                }
            }
            catch
            {
                // Best effort: the body is not JSON or not in the expected shape; fall through
                // to the generic message below, which includes the raw response.
            }

            _requestContext.Logger.Error($"{MsalErrorMessage.ManagedIdentityUnexpectedErrorResponse}. Error response received from the server: {response}.");
            return $"{MsalErrorMessage.ManagedIdentityUnexpectedErrorResponse}. Error response received from the server: {response}.";
        }

        /// <summary>
        /// Translates a failure raised while sending or handling the request.
        /// </summary>
        /// <remarks>
        /// Network and format failures, and any exception that is not already an
        /// <see cref="MsalServiceException"/>, are wrapped and thrown from here. Cancellation
        /// (any <see cref="OperationCanceledException"/>, including the <see cref="TaskCanceledException"/>
        /// raised by an HTTP timeout) and existing <see cref="MsalServiceException"/>s are only logged
        /// or ignored. In those cases the method returns so the caller can rethrow the original exception.
        /// </remarks>
        private void HandleException(Exception ex,
            ManagedIdentitySource managedIdentitySource = ManagedIdentitySource.None,
            string additionalInfo = null)
        {
            ManagedIdentitySource source = managedIdentitySource != ManagedIdentitySource.None ? managedIdentitySource : _sourceType;

            if (ex is HttpRequestException httpRequestException)
            {
                CreateAndThrowException(MsalError.ManagedIdentityUnreachableNetwork, httpRequestException.Message, httpRequestException, source);
            }
            else if (ex is OperationCanceledException)
            {
                // Matches TaskCanceledException as before, and also a plain OperationCanceledException
                // (e.g. from ThrowIfCancellationRequested in a derived HandleResponseAsync). That case
                // used to be wrapped as ManagedIdentityRequestFailed, hiding the cancellation from callers.
                _requestContext.Logger.Error(TimeoutError);
            }
            else if (ex is FormatException formatException)
            {
                string errorMessage = additionalInfo ?? formatException.Message;
                _requestContext.Logger.Error($"[Managed Identity] Format Exception: {errorMessage}");
                CreateAndThrowException(MsalError.InvalidManagedIdentityEndpoint, errorMessage, formatException, source);
            }
            else if (ex is not MsalServiceException)
            {
                _requestContext.Logger.Error($"[Managed Identity] Exception: {ex.Message}");
                CreateAndThrowException(MsalError.ManagedIdentityRequestFailed, ex.Message, ex, source);
            }
        }

        /// <summary>
        /// Creates a managed identity exception with no HTTP status code and throws it.
        /// </summary>
        private static void CreateAndThrowException(string errorCode,
            string errorMessage,
            Exception innerException,
            ManagedIdentitySource source)
        {
            MsalException exception = MsalServiceExceptionFactory.CreateManagedIdentityException(
                errorCode,
                errorMessage,
                innerException,
                source,
                null);

            throw exception;
        }
    }
}