Skip to content

Commit 8fc0af1

Browse files
committed
Update GenericOAuthProvider.cs
1 parent f806625 commit 8fc0af1

1 file changed

Lines changed: 153 additions & 51 deletions

File tree

src/ModelContextProtocol/Authentication/GenericOAuthProvider.cs

Lines changed: 153 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,18 @@ public class GenericOAuthProvider : IMcpCredentialProvider
2424
private readonly List<string> _scopes;
2525
private readonly string _clientId;
2626
private readonly string _clientSecret;
27-
private readonly HttpClient _httpClient;
28-
private readonly AuthorizationHelpers _authorizationHelpers;
27+
private readonly HttpClient _httpClient; private readonly AuthorizationHelpers _authorizationHelpers;
2928
private readonly ILogger _logger;
29+
private readonly Func<IReadOnlyList<Uri>, Uri?> _authServerSelector;
3030

3131
// Lazy-initialized shared HttpClient for when no client is provided
3232
private static readonly Lazy<HttpClient> _defaultHttpClient = new(() => new HttpClient());
3333

3434
private static readonly JsonSerializerOptions _jsonOptions = new() { PropertyNameCaseInsensitive = true };
35-
3635
private TokenContainer? _token;
37-
private AuthorizationServerMetadata? _authServerMetadata; /// <summary>
36+
private AuthorizationServerMetadata? _authServerMetadata;
37+
38+
/// <summary>
3839
/// Initializes a new instance of the <see cref="GenericOAuthProvider"/> class.
3940
/// </summary>
4041
/// <param name="serverUrl">The MCP server URL.</param>
@@ -53,18 +54,56 @@ public GenericOAuthProvider(
5354
string clientSecret = "",
5455
Uri? redirectUri = null,
5556
IEnumerable<string>? scopes = null,
56-
ILogger<GenericOAuthProvider>? logger = null)
57+
ILogger<GenericOAuthProvider>? logger = null) : this(serverUrl, httpClient, authorizationHelpers, clientId, clientSecret, redirectUri, scopes, logger, null)
58+
{
59+
} /// <summary>
60+
/// Initializes a new instance of the <see cref="GenericOAuthProvider"/> class with explicit authorization server selection.
61+
/// </summary>
62+
/// <param name="serverUrl">The MCP server URL.</param>
63+
/// <param name="httpClient">The HTTP client to use for OAuth requests. If null, a default HttpClient will be used.</param>
64+
/// <param name="authorizationHelpers">The authorization helpers.</param>
65+
/// <param name="clientId">OAuth client ID.</param>
66+
/// <param name="clientSecret">OAuth client secret.</param>
67+
/// <param name="redirectUri">OAuth redirect URI.</param>
68+
/// <param name="scopes">OAuth scopes.</param>
69+
/// <param name="logger">The logger instance. If null, a NullLogger will be used.</param>
70+
/// <param name="authServerSelector">Function to select which authorization server to use from available servers. If null, uses default selection strategy.</param>
71+
/// <exception cref="ArgumentNullException">Thrown when serverUrl is null.</exception>
72+
public GenericOAuthProvider(
73+
Uri serverUrl,
74+
HttpClient? httpClient,
75+
AuthorizationHelpers? authorizationHelpers,
76+
string clientId,
77+
string clientSecret,
78+
Uri? redirectUri,
79+
IEnumerable<string>? scopes,
80+
ILogger<GenericOAuthProvider>? logger,
81+
Func<IReadOnlyList<Uri>, Uri?>? authServerSelector)
5782
{
5883
if (serverUrl == null) throw new ArgumentNullException(nameof(serverUrl));
59-
_serverUrl = serverUrl;
84+
85+
_serverUrl = serverUrl;
6086
_httpClient = httpClient ?? _defaultHttpClient.Value;
6187
_authorizationHelpers = authorizationHelpers ?? new AuthorizationHelpers(_httpClient);
6288
_logger = (ILogger?)logger ?? NullLogger.Instance;
6389

6490
_redirectUri = redirectUri ?? new Uri("http://localhost:8080/callback");
6591
_scopes = scopes?.ToList() ?? [];
66-
_clientId = clientId;
67-
_clientSecret = clientSecret;
92+
_clientId = clientId ?? "demo-client";
93+
_clientSecret = clientSecret ?? "";
94+
95+
// Set up authorization server selection strategy
96+
_authServerSelector = authServerSelector ?? DefaultAuthServerSelector;
97+
}
98+
99+
/// <summary>
100+
/// Default authorization server selection strategy that selects the first available server.
101+
/// </summary>
102+
/// <param name="availableServers">List of available authorization servers.</param>
103+
/// <returns>The selected authorization server, or null if none are available.</returns>
104+
private static Uri? DefaultAuthServerSelector(IReadOnlyList<Uri> availableServers)
105+
{
106+
return availableServers.FirstOrDefault();
68107
}
69108

70109
/// <inheritdoc />
@@ -81,7 +120,6 @@ public GenericOAuthProvider(
81120

82121
return GetBearerTokenAsync(cancellationToken);
83122
} /// <inheritdoc />
84-
85123
public async Task<McpUnauthorizedResponseResult> HandleUnauthorizedResponseAsync(
86124
HttpResponseMessage response,
87125
string scheme,
@@ -92,43 +130,98 @@ public async Task<McpUnauthorizedResponseResult> HandleUnauthorizedResponseAsync
92130
{
93131
return new McpUnauthorizedResponseResult(false, null);
94132
}
95-
try
133+
try
96134
{
97-
// Get available authorization servers from the 401 response
98-
var availableAuthorizationServers = await _authorizationHelpers.GetAvailableAuthorizationServersAsync(
99-
response,
100-
_serverUrl,
101-
cancellationToken);
102-
103-
// Select the first available authorization server (or implement your own selection logic)
104-
var selectedAuthServer = availableAuthorizationServers.FirstOrDefault();
105-
106-
if (selectedAuthServer != null)
107-
{
108-
// Get auth server metadata
109-
var authServerMetadata = await GetAuthServerMetadataAsync(selectedAuthServer, cancellationToken);
110-
111-
if (authServerMetadata != null)
112-
{
113-
// Store auth server metadata for future refresh operations
114-
_authServerMetadata = authServerMetadata;
115-
116-
// Do the OAuth flow
117-
var token = await InitiateAuthorizationCodeFlowAsync(authServerMetadata, cancellationToken);
118-
if (token != null)
119-
{
120-
_token = token;
121-
return new McpUnauthorizedResponseResult(true, BearerScheme);
122-
}
123-
}
124-
}
125-
126-
return new McpUnauthorizedResponseResult(false, null); }
135+
return await PerformOAuthAuthorizationAsync(response, cancellationToken);
136+
}
127137
catch (Exception ex)
128138
{
129-
_logger.LogError(ex, "Error handling auth challenge");
139+
_logger.LogError(ex, "Error handling OAuth authorization");
140+
return new McpUnauthorizedResponseResult(false, null);
141+
}
142+
}
143+
144+
/// <summary>
145+
/// Performs OAuth authorization by selecting an appropriate authorization server and completing the OAuth flow.
146+
/// </summary>
147+
/// <param name="response">The 401 Unauthorized response containing authentication challenge.</param>
148+
/// <param name="cancellationToken">Cancellation token.</param>
149+
/// <returns>Result indicating whether authorization was successful.</returns>
150+
private async Task<McpUnauthorizedResponseResult> PerformOAuthAuthorizationAsync(
151+
HttpResponseMessage response,
152+
CancellationToken cancellationToken)
153+
{
154+
// Get available authorization servers from the 401 response
155+
var availableAuthorizationServers = await _authorizationHelpers.GetAvailableAuthorizationServersAsync(
156+
response,
157+
_serverUrl,
158+
cancellationToken);
159+
160+
if (!availableAuthorizationServers.Any())
161+
{
162+
_logger.LogWarning("No authorization servers found in authentication challenge");
163+
return new McpUnauthorizedResponseResult(false, null);
164+
}
165+
166+
// Select authorization server using configured strategy
167+
var selectedAuthServer = SelectAuthorizationServer(availableAuthorizationServers);
168+
169+
if (selectedAuthServer == null)
170+
{
171+
_logger.LogWarning("Authorization server selection returned null. Available servers: {Servers}",
172+
string.Join(", ", availableAuthorizationServers));
173+
return new McpUnauthorizedResponseResult(false, null);
174+
}
175+
176+
_logger.LogInformation("Selected authorization server: {Server} from {Count} available servers",
177+
selectedAuthServer, availableAuthorizationServers.Count);
178+
179+
// Get auth server metadata
180+
var authServerMetadata = await GetAuthServerMetadataAsync(selectedAuthServer, cancellationToken);
181+
182+
if (authServerMetadata == null)
183+
{
184+
_logger.LogError("Failed to retrieve metadata for authorization server: {Server}", selectedAuthServer);
130185
return new McpUnauthorizedResponseResult(false, null);
131186
}
187+
188+
// Store auth server metadata for future refresh operations
189+
_authServerMetadata = authServerMetadata;
190+
191+
// Perform the OAuth flow
192+
var token = await InitiateAuthorizationCodeFlowAsync(authServerMetadata, cancellationToken);
193+
if (token != null)
194+
{
195+
_token = token;
196+
_logger.LogInformation("OAuth authorization completed successfully");
197+
return new McpUnauthorizedResponseResult(true, BearerScheme);
198+
}
199+
200+
_logger.LogError("OAuth authorization flow failed");
201+
return new McpUnauthorizedResponseResult(false, null);
202+
} /// <summary>
203+
/// Selects an authorization server from the available options using the configured selection strategy.
204+
/// </summary>
205+
/// <param name="availableServers">List of available authorization servers.</param>
206+
/// <returns>Selected authorization server URI, or null if selection failed.</returns>
207+
private Uri? SelectAuthorizationServer(IReadOnlyList<Uri> availableServers)
208+
{
209+
if (!availableServers.Any())
210+
{
211+
return null;
212+
}
213+
214+
// Use the configured selection function
215+
var selected = _authServerSelector(availableServers);
216+
217+
if (selected != null && !availableServers.Contains(selected))
218+
{
219+
_logger.LogWarning("Authorization server selector returned a server not in the available list: {Selected}. " +
220+
"Available servers: {Available}", selected, string.Join(", ", availableServers));
221+
return null;
222+
}
223+
224+
return selected;
132225
}
133226

134227
private async Task<string?> GetBearerTokenAsync(CancellationToken cancellationToken = default)
@@ -178,7 +271,8 @@ public async Task<McpUnauthorizedResponseResult> HandleUnauthorizedResponseAsync
178271

179272
return metadata;
180273
}
181-
} }
274+
}
275+
}
182276
catch (Exception ex)
183277
{
184278
_logger.LogError(ex, "Error fetching auth server metadata from {Path}", path);
@@ -226,22 +320,25 @@ public async Task<McpUnauthorizedResponseResult> HandleUnauthorizedResponseAsync
226320

227321
return tokenResponse;
228322
}
229-
} }
323+
}
324+
}
230325
catch (Exception ex)
231326
{
232327
_logger.LogError(ex, "Error refreshing token");
233328
}
234329

235330
return null;
236-
} private async Task<TokenContainer?> InitiateAuthorizationCodeFlowAsync(
331+
}
332+
333+
private async Task<TokenContainer?> InitiateAuthorizationCodeFlowAsync(
237334
AuthorizationServerMetadata authServerMetadata,
238335
CancellationToken cancellationToken)
239336
{
240337
var codeVerifier = GenerateCodeVerifier();
241338
var codeChallenge = GenerateCodeChallenge(codeVerifier);
242339

243340
var authUrl = BuildAuthorizationUrl(authServerMetadata, codeChallenge);
244-
var authCode = await GetAuthorizationCodeAsync(authUrl, cancellationToken);
341+
var authCode = await GetAuthorizationCodeAsync(authUrl, cancellationToken);
245342
if (string.IsNullOrEmpty(authCode))
246343
return null;
247344

@@ -290,7 +387,7 @@ private Uri BuildAuthorizationUrl(AuthorizationServerMetadata authServerMetadata
290387
OpenBrowser(authorizationUrl);
291388

292389
var context = await listener.GetContextAsync();
293-
var query = HttpUtility.ParseQueryString(context.Request.Url?.Query ?? string.Empty);
390+
var query = HttpUtility.ParseQueryString(context.Request.Url?.Query ?? string.Empty);
294391
var code = query["code"];
295392
var error = query["error"];
296393

@@ -300,7 +397,8 @@ private Uri BuildAuthorizationUrl(AuthorizationServerMetadata authServerMetadata
300397
context.Response.ContentType = "text/html";
301398
context.Response.OutputStream.Write(buffer, 0, buffer.Length);
302399
context.Response.Close();
303-
if (!string.IsNullOrEmpty(error))
400+
401+
if (!string.IsNullOrEmpty(error))
304402
{
305403
_logger.LogError("Auth error: {Error}", error);
306404
return null;
@@ -312,7 +410,8 @@ private Uri BuildAuthorizationUrl(AuthorizationServerMetadata authServerMetadata
312410
return null;
313411
}
314412

315-
return code; }
413+
return code;
414+
}
316415
catch (Exception ex)
317416
{
318417
_logger.LogError(ex, "Error getting auth code");
@@ -362,13 +461,15 @@ private Uri BuildAuthorizationUrl(AuthorizationServerMetadata authServerMetadata
362461
{
363462
tokenResponse.ObtainedAt = DateTimeOffset.UtcNow;
364463
return tokenResponse;
365-
} }
464+
}
465+
}
366466
else
367467
{
368468
_logger.LogError("Token exchange failed: {StatusCode}", response.StatusCode);
369469
var error = await response.Content.ReadAsStringAsync(cancellationToken);
370470
_logger.LogError("Error: {Error}", error);
371-
} }
471+
}
472+
}
372473
catch (Exception ex)
373474
{
374475
_logger.LogError(ex, "Exception during token exchange");
@@ -386,7 +487,8 @@ private void OpenBrowser(Uri url)
386487
FileName = url.ToString(),
387488
UseShellExecute = true
388489
};
389-
System.Diagnostics.Process.Start(psi); }
490+
System.Diagnostics.Process.Start(psi);
491+
}
390492
catch (Exception ex)
391493
{
392494
_logger.LogError(ex, "Error opening browser");

0 commit comments

Comments
 (0)