-
Notifications
You must be signed in to change notification settings - Fork 403
Expand file tree
/
Copy pathOAuth2Client.cs
More file actions
401 lines (354 loc) · 16.4 KB
/
OAuth2Client.cs
File metadata and controls
401 lines (354 loc) · 16.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Globalization;
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Extensibility;
using Microsoft.Identity.Client.Http;
using Microsoft.Identity.Client.Instance.Discovery;
using Microsoft.Identity.Client.Instance.Oidc;
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.Utils;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Identity.Client.Http.Retry;
#if SUPPORTS_SYSTEM_TEXT_JSON
using System.Text.Json;
#else
using Microsoft.Identity.Json;
#endif
namespace Microsoft.Identity.Client.OAuth2
{
    /// <summary>
    /// Responsible for talking to all the Identity provider endpoints:
    /// - instance discovery
    /// - endpoint metadata
    /// - mex
    /// - /token endpoint via TokenClient
    /// - device code endpoint
    /// </summary>
    /// <remarks>
    /// An instance accumulates headers, query parameters and body parameters through its
    /// <c>Add*</c> methods and sends them with <c>ExecuteRequestAsync</c>. The accumulated
    /// state is not cleared after a request is sent.
    /// </remarks>
    internal class OAuth2Client
    {
        // Request headers, seeded at construction with the values from MsalIdHelper.GetMsalIdParameters.
        private readonly Dictionary<string, string> _headers;

        // Appended to the endpoint's query string on every request.
        private readonly Dictionary<string, string> _queryParameters = new Dictionary<string, string>();

        // Sent form-url-encoded as the body of POST requests.
        private readonly IDictionary<string, string> _bodyParameters = new Dictionary<string, string>();

        private readonly IHttpManager _httpManager;

        // Client certificate passed to the HTTP manager for POST requests only; not validated, may be null.
        private readonly X509Certificate2 _mtlsCertificate;

        /// <summary>
        /// Creates a client that sends its requests through <paramref name="httpManager"/>.
        /// </summary>
        /// <param name="logger">Passed to <c>MsalIdHelper.GetMsalIdParameters</c> to build the initial headers.</param>
        /// <param name="httpManager">HTTP transport used for every request. Must not be null.</param>
        /// <param name="mtlsCertificate">Certificate handed to the transport on POST requests; may be null.</param>
        /// <exception cref="ArgumentNullException"><paramref name="httpManager"/> is null.</exception>
        public OAuth2Client(ILoggerAdapter logger, IHttpManager httpManager, X509Certificate2 mtlsCertificate)
        {
            _headers = new Dictionary<string, string>(MsalIdHelper.GetMsalIdParameters(logger));
            _httpManager = httpManager ?? throw new ArgumentNullException(nameof(httpManager));
            _mtlsCertificate = mtlsCertificate;
        }

        /// <summary>
        /// Adds or overwrites a query parameter. Null/whitespace keys or values are silently ignored.
        /// </summary>
        public void AddQueryParameter(string key, string value)
        {
            if (!string.IsNullOrWhiteSpace(key) && !string.IsNullOrWhiteSpace(value))
            {
                _queryParameters[key] = value;
            }
        }

        /// <summary>
        /// Adds or overwrites a body parameter. Null/whitespace keys or values are silently ignored.
        /// </summary>
        public void AddBodyParameter(string key, string value)
        {
            if (!string.IsNullOrWhiteSpace(key) && !string.IsNullOrWhiteSpace(value))
            {
                _bodyParameters[key] = value;
            }
        }

        /// <summary>
        /// Adds or overwrites a request header (no validation of key or value).
        /// </summary>
        internal void AddHeader(string key, string value)
        {
            _headers[key] = value;
        }

        /// <summary>
        /// Returns a read-only view over the current body parameters.
        /// The view wraps the live dictionary, so later additions are visible through it.
        /// </summary>
        internal IReadOnlyDictionary<string, string> GetBodyParameters()
        {
            return new ReadOnlyDictionary<string, string>(_bodyParameters);
        }

        /// <summary>
        /// Issues a GET to <paramref name="endpoint"/> and deserializes an <see cref="InstanceDiscoveryResponse"/>.
        /// </summary>
        public Task<InstanceDiscoveryResponse> DiscoverAadInstanceAsync(Uri endpoint, RequestContext requestContext)
        {
            return ExecuteRequestAsync<InstanceDiscoveryResponse>(endpoint, HttpMethod.Get, requestContext);
        }

        /// <summary>
        /// Issues a GET to <paramref name="endpoint"/> and deserializes an <see cref="OidcMetadata"/>.
        /// </summary>
        public Task<OidcMetadata> DiscoverOidcMetadataAsync(Uri endpoint, RequestContext requestContext)
        {
            return ExecuteRequestAsync<OidcMetadata>(endpoint, HttpMethod.Get, requestContext);
        }

        /// <summary>
        /// POSTs the accumulated body parameters to <paramref name="endPoint"/> and deserializes an
        /// <see cref="MsalTokenResponse"/>. A 200 OK response is not inspected for an embedded error.
        /// </summary>
        /// <param name="endPoint">Token endpoint.</param>
        /// <param name="requestContext">Per-request logger, correlation id, cancellation and telemetry.</param>
        /// <param name="addCommonHeaders">Whether to add the correlation-id headers before sending.</param>
        /// <param name="onBeforePostRequestHandler">Optional callbacks run before the POST is sent.</param>
        internal Task<MsalTokenResponse> GetTokenAsync(
            Uri endPoint,
            RequestContext requestContext,
            bool addCommonHeaders,
            IList<Func<OnBeforeTokenRequestData, Task>> onBeforePostRequestHandler)
        {
            return ExecuteRequestAsync<MsalTokenResponse>(
                endPoint,
                HttpMethod.Post,
                requestContext,
                false,
                addCommonHeaders,
                onBeforePostRequestHandler);
        }

        /// <summary>
        /// Sends a request to <paramref name="endPoint"/> using the accumulated headers and query
        /// parameters (plus the body parameters for POST). The response is deserialized into <typeparamref name="T"/>.
        /// </summary>
        /// <param name="endPoint">Target endpoint; the accumulated query parameters are appended to it.</param>
        /// <param name="method">Only <see cref="HttpMethod.Post"/> is special-cased; any other value is sent as GET.</param>
        /// <param name="requestContext">Supplies the logger, correlation id, cancellation token, retry policy and telemetry event.</param>
        /// <param name="expectErrorsOn200OK">When true, a 200 OK response whose body carries an OAuth2 <c>error</c> is treated as a failure.</param>
        /// <param name="addCommonHeaders">When true, sets the correlation-id headers before sending.</param>
        /// <param name="onBeforePostRequestHandlers">
        /// Optional callbacks awaited in order before a POST is sent. Each receives the body parameters,
        /// headers and request URI. The URI they leave on the request data is the one used.
        /// </param>
        /// <returns>The deserialized response body.</returns>
        /// <exception cref="MsalServiceException">
        /// The server returned a non-200 status. Also thrown for an error payload on 200 OK when <paramref name="expectErrorsOn200OK"/> is set.
        /// </exception>
        internal async Task<T> ExecuteRequestAsync<T>(
            Uri endPoint,
            HttpMethod method,
            RequestContext requestContext,
            bool expectErrorsOn200OK = false,
            bool addCommonHeaders = true,
            IList<Func<OnBeforeTokenRequestData, Task>> onBeforePostRequestHandlers = null)
        {
            // Requests that are replayed by PKeyAuth do not need to have headers added because they already exist.
            if (addCommonHeaders)
            {
                AddCommonHeaders(requestContext);
            }

            HttpResponse response;
            Uri endpointUri = AddExtraQueryParams(endPoint);

            using (requestContext.Logger.LogBlockDuration($"[Oauth2Client] Sending {method} request "))
            {
                IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory;
                IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.STS);

                try
                {
                    if (method == HttpMethod.Post)
                    {
                        if (onBeforePostRequestHandlers != null)
                        {
                            requestContext.Logger.Verbose(() => "[Oauth2Client] Processing onBeforePostRequestData ");
                            var requestData = new OnBeforeTokenRequestData(_bodyParameters, _headers, endpointUri, requestContext.UserCancellationToken);
                            foreach (var handler in onBeforePostRequestHandlers)
                            {
                                await handler(requestData).ConfigureAwait(false);
                            }

                            // A handler may have redirected the request.
                            endpointUri = requestData.RequestUri;
                        }

                        // The content is only needed until SendRequestAsync completes; dispose it afterwards.
                        using (var body = new FormUrlEncodedContent(_bodyParameters))
                        {
                            response = await _httpManager.SendRequestAsync(
                                endpointUri,
                                _headers,
                                body: body,
                                method: HttpMethod.Post,
                                logger: requestContext.Logger,
                                doNotThrow: false,
                                mtlsCertificate: _mtlsCertificate,
                                validateServerCertificate: null,
                                cancellationToken: requestContext.UserCancellationToken,
                                retryPolicy: retryPolicy)
                                .ConfigureAwait(false);
                        }
                    }
                    else
                    {
                        response = await _httpManager.SendRequestAsync(
                            endpointUri,
                            _headers,
                            body: null,
                            method: HttpMethod.Get,
                            logger: requestContext.Logger,
                            doNotThrow: false,
                            mtlsCertificate: null,
                            validateServerCertificate: null,
                            cancellationToken: requestContext.UserCancellationToken,
                            retryPolicy: retryPolicy)
                            .ConfigureAwait(false);
                    }
                }
                catch (Exception ex)
                {
                    // A caller-requested cancellation is not a request failure: rethrow without error logging.
                    // OperationCanceledException also covers its subclass TaskCanceledException.
                    if (ex is OperationCanceledException && requestContext.UserCancellationToken.IsCancellationRequested)
                    {
                        throw;
                    }

                    requestContext.Logger.ErrorPii(
                        string.Format(MsalErrorMessage.RequestFailureErrorMessagePii,
                            requestContext.ApiEvent?.ApiIdString,
                            $"{endpointUri.Scheme}://{endpointUri.Host}{endpointUri.AbsolutePath}",
                            requestContext.ServiceBundle.Config.ClientId),
                        string.Format(MsalErrorMessage.RequestFailureErrorMessage,
                            requestContext.ApiEvent?.ApiIdString,
                            $"{endpointUri.Scheme}://{endpointUri.Host}"));
                    requestContext.Logger.ErrorPii(ex);
                    throw;
                }
            }

            if (requestContext.ApiEvent != null)
            {
                requestContext.ApiEvent.DurationInHttpInMs += _httpManager.LastRequestDurationInMs;
            }

            if (response.StatusCode != HttpStatusCode.OK || expectErrorsOn200OK)
            {
                requestContext.Logger.Verbose(() => "[Oauth2Client] Processing error response ");

                // Non-200 responses are reported by CreateResponse below. Only a 200 OK (reachable here
                // only when expectErrorsOn200OK is set) needs its body inspected for an embedded error.
                if (response.StatusCode == HttpStatusCode.OK)
                {
                    ThrowIfOkResponseContainsError(response, requestContext);
                }
            }

            return CreateResponse<T>(response, requestContext);
        }

        /// <summary>
        /// For a 200 OK response: throws the corresponding service exception if the body parses as an
        /// <see cref="MsalTokenResponse"/> with a non-empty <c>Error</c>. Does nothing otherwise.
        /// </summary>
        private static void ThrowIfOkResponseContainsError(HttpResponse response, RequestContext requestContext)
        {
            if (string.IsNullOrWhiteSpace(response.Body))
            {
                return;
            }

            MsalTokenResponse msalTokenResponse;
            try
            {
                msalTokenResponse = JsonHelper.DeserializeFromJson<MsalTokenResponse>(response.Body);
            }
            catch (JsonException) // in the rare case we get an error response we cannot deserialize
            {
                // CreateResponse deserializes the body again and reports any failure there.
                return;
            }

            if (!string.IsNullOrEmpty(msalTokenResponse?.Error))
            {
                ThrowServerException(response, requestContext);
            }
        }

        /// <summary>
        /// Adds or overwrites a body parameter. Unlike the string overload, empty values are not filtered out.
        /// </summary>
        internal void AddBodyParameter(KeyValuePair<string, string> kvp)
        {
            // Indexer rather than Add: overwrite on duplicate key, consistent with the other Add* methods.
            _bodyParameters[kvp.Key] = kvp.Value;
        }

        /// <summary>
        /// Sets the correlation-id header and asks the server to echo it back in the response.
        /// </summary>
        private void AddCommonHeaders(RequestContext requestContext)
        {
            // Indexer rather than Add, so a second request on the same client does not throw on duplicate keys.
            _headers[OAuth2Header.CorrelationId] = requestContext.CorrelationId.ToString();
            _headers[OAuth2Header.RequestCorrelationIdInResponse] = "true";
        }

        /// <summary>
        /// Converts an HTTP response into <typeparamref name="T"/>.
        /// </summary>
        /// <remarks>
        /// Any non-200 status results in an <see cref="MsalServiceException"/>. Otherwise the echoed
        /// correlation id is checked (mismatch is only logged as a warning) and the body is deserialized.
        /// </remarks>
        public static T CreateResponse<T>(HttpResponse response, RequestContext requestContext)
        {
            if (response.StatusCode != HttpStatusCode.OK)
            {
                ThrowServerException(response, requestContext);
            }

            VerifyCorrelationIdHeaderInResponse(response.HeadersAsDictionary, requestContext);

            using (requestContext.Logger.LogBlockDuration("[OAuth2Client] Deserializing response"))
            {
                return JsonHelper.DeserializeFromJson<T>(response.Body);
            }
        }

        /// <summary>
        /// Builds and always throws an <see cref="MsalServiceException"/> describing a failed response.
        /// </summary>
        /// <remarks>
        /// Exception selection: an OAuth2 error parsed from the body if present, otherwise a
        /// non-parsable / unknown / HTTP-status error. An <c>authorization_pending</c> error is logged
        /// at Info instead of Error.
        /// </remarks>
        private static void ThrowServerException(HttpResponse response, RequestContext requestContext)
        {
            bool shouldLogAsError = true;

            var httpErrorCodeMessage = string.Format(CultureInfo.InvariantCulture, "HttpStatusCode: {0}: {1}", (int)response.StatusCode, response.StatusCode.ToString());
            requestContext.Logger.Info(httpErrorCodeMessage);

            MsalServiceException exceptionToThrow;
            try
            {
                exceptionToThrow = ExtractErrorsFromTheResponse(response, ref shouldLogAsError, requestContext);
            }
            catch (JsonException) // in the rare case we get an error response we cannot deserialize
            {
                exceptionToThrow = MsalServiceExceptionFactory.FromHttpResponse(
                    MsalError.NonParsableOAuthError,
                    MsalErrorMessage.NonParsableOAuthError,
                    response,
                    null,
                    requestContext);
            }
            catch (Exception ex)
            {
                exceptionToThrow = MsalServiceExceptionFactory.FromHttpResponse(
                    MsalError.UnknownError,
                    response.Body,
                    response,
                    ex,
                    requestContext);
            }

            // No OAuth2 error in the body (e.g. empty 404 body): fall back to a status-code based error.
            exceptionToThrow ??= MsalServiceExceptionFactory.FromHttpResponse(
                response.StatusCode == HttpStatusCode.NotFound
                    ? MsalError.HttpStatusNotFound
                    : MsalError.HttpStatusCodeNotOk,
                httpErrorCodeMessage,
                response,
                null,
                requestContext);

            if (shouldLogAsError)
            {
                // Null-conditional access so that logging can never replace the service exception
                // with a NullReferenceException when no authority is configured.
                var authorityInfo = requestContext.ServiceBundle.Config.Authority?.AuthorityInfo;

                requestContext.Logger.ErrorPii(
                    string.Format(MsalErrorMessage.RequestFailureErrorMessagePii,
                        requestContext.ApiEvent?.ApiIdString,
                        authorityInfo?.CanonicalAuthority,
                        requestContext.ServiceBundle.Config.ClientId),
                    string.Format(MsalErrorMessage.RequestFailureErrorMessage,
                        requestContext.ApiEvent?.ApiIdString,
                        authorityInfo?.Host));

                requestContext.Logger.ErrorPii(exceptionToThrow);
            }
            else
            {
                requestContext.Logger.InfoPii(exceptionToThrow);
            }

            throw exceptionToThrow;
        }

        /// <summary>
        /// Tries to turn the response body into a service exception.
        /// </summary>
        /// <param name="response">The failed HTTP response.</param>
        /// <param name="shouldLogAsError">Set to false when the error is <c>authorization_pending</c>.</param>
        /// <param name="context">Optional request context forwarded to the exception factory.</param>
        /// <returns>
        /// The exception to throw. Null if the body is empty or carries no OAuth2 <c>error</c>.
        /// </returns>
        /// <exception cref="JsonException">The body is not parsable JSON and the status is not 429.</exception>
        private static MsalServiceException ExtractErrorsFromTheResponse(HttpResponse response, ref bool shouldLogAsError, RequestContext context = null)
        {
            // In cases where the end-point is not found (404) response.body will be empty.
            if (string.IsNullOrWhiteSpace(response.Body))
            {
                return null;
            }

            MsalTokenResponse msalTokenResponse;
            try
            {
                msalTokenResponse = JsonHelper.DeserializeFromJson<MsalTokenResponse>(response.Body);
            }
            catch (JsonException)
            {
                // Throttled responses for client credential flows do not have a parsable response.
                if ((int)response.StatusCode == 429)
                {
                    return MsalServiceExceptionFactory.FromThrottledAuthenticationResponse(response);
                }

                throw;
            }

            if (msalTokenResponse?.Error == null)
            {
                return null;
            }

            // For device code flow, AuthorizationPending can occur a lot while waiting
            // for the user to auth via browser and this causes a lot of error noise in the logs.
            // So suppress this particular case to an Info so we still see the data but don't
            // log it as an error since it's expected behavior while waiting for the user.
            if (string.Equals(msalTokenResponse.Error, OAuth2Error.AuthorizationPending, StringComparison.OrdinalIgnoreCase))
            {
                shouldLogAsError = false;
            }

            return MsalServiceExceptionFactory.FromHttpResponse(
                msalTokenResponse.Error,
                msalTokenResponse.ErrorDescription,
                response,
                context: context);
        }

        /// <summary>
        /// Returns <paramref name="endPoint"/> with the accumulated query parameters appended.
        /// </summary>
        private Uri AddExtraQueryParams(Uri endPoint)
        {
            var endpointUri = new UriBuilder(endPoint);
            string extraQp = _queryParameters.ToQueryParameter();
            endpointUri.AppendQueryParameters(extraQp);
            return endpointUri.Uri;
        }

        /// <summary>
        /// Logs a warning if the response echoes a correlation id different from the one sent.
        /// Header names are compared trimmed and case-insensitively. Only the first matching header
        /// is examined, and a missing or empty value is ignored.
        /// </summary>
        private static void VerifyCorrelationIdHeaderInResponse(
            IDictionary<string, string> headers,
            RequestContext requestContext)
        {
            if (headers == null)
            {
                return;
            }

            foreach (KeyValuePair<string, string> header in headers)
            {
                if (!string.Equals(header.Key.Trim(), OAuth2Header.CorrelationId, StringComparison.OrdinalIgnoreCase))
                {
                    continue;
                }

                if (!string.IsNullOrEmpty(header.Value))
                {
                    string correlationIdHeader = header.Value.Trim();
                    if (!string.Equals(
                            correlationIdHeader,
                            requestContext.CorrelationId.ToString(),
                            StringComparison.OrdinalIgnoreCase))
                    {
                        requestContext.Logger.WarningPii(
                            string.Format(
                                CultureInfo.InvariantCulture,
                                "Returned correlation id '{0}' does not match the sent correlation id '{1}'",
                                correlationIdHeader,
                                requestContext.CorrelationId),
                            "Returned correlation id does not match the sent correlation id");
                    }
                }

                break;
            }
        }
    }
}