Skip to content

Commit a059ef1

Browse files
committed
fix(client): dedupe concurrent oauth refreshes
1 parent e86b183 commit a059ef1

3 files changed

Lines changed: 152 additions & 11 deletions

File tree

.changeset/fresh-plums-cheer.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@modelcontextprotocol/client': patch
3+
---
4+
5+
Deduplicate concurrent OAuth refreshes for the same provider, authorization server, resource, and refresh token so parallel `auth()` callers reuse the in-flight refresh instead of replaying a rotating refresh token.

packages/client/src/client/auth.ts

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,46 @@ function isClientAuthMethod(method: string): method is ClientAuthMethod {
384384

385385
const AUTHORIZATION_CODE_RESPONSE_TYPE = 'code';
386386
const AUTHORIZATION_CODE_CHALLENGE_METHOD = 'S256';
387+
const inFlightRefreshAuthorizations = new WeakMap<OAuthClientProvider, Map<string, Promise<AuthResult>>>();
388+
389+
function buildRefreshAuthorizationKey(authorizationServerUrl: string | URL, refreshToken: string, resource?: URL): string {
390+
return JSON.stringify({
391+
authorizationServerUrl: String(authorizationServerUrl),
392+
refreshToken,
393+
resource: resource?.toString()
394+
});
395+
}
396+
397+
async function runWithRefreshAuthorizationLock(
398+
provider: OAuthClientProvider,
399+
key: string,
400+
runRefreshAuthorization: () => Promise<AuthResult>
401+
): Promise<AuthResult> {
402+
const existingRefreshAuthorization = inFlightRefreshAuthorizations.get(provider)?.get(key);
403+
if (existingRefreshAuthorization) {
404+
return await existingRefreshAuthorization;
405+
}
406+
407+
let refreshesByKey = inFlightRefreshAuthorizations.get(provider);
408+
if (!refreshesByKey) {
409+
refreshesByKey = new Map();
410+
inFlightRefreshAuthorizations.set(provider, refreshesByKey);
411+
}
412+
413+
const refreshAuthorizationPromise = runRefreshAuthorization();
414+
refreshesByKey.set(key, refreshAuthorizationPromise);
415+
try {
416+
return await refreshAuthorizationPromise;
417+
} finally {
418+
const currentRefreshesByKey = inFlightRefreshAuthorizations.get(provider);
419+
if (currentRefreshesByKey?.get(key) === refreshAuthorizationPromise) {
420+
currentRefreshesByKey.delete(key);
421+
if (currentRefreshesByKey.size === 0) {
422+
inFlightRefreshAuthorizations.delete(provider);
423+
}
424+
}
425+
}
426+
}
387427

388428
/**
389429
* Determines the best client authentication method to use based on server support and client configuration.
@@ -731,18 +771,21 @@ async function authInternal(
731771
// Handle token refresh or new authorization
732772
if (tokens?.refresh_token) {
733773
try {
734-
// Attempt to refresh the token
735-
const newTokens = await refreshAuthorization(authorizationServerUrl, {
736-
metadata,
737-
clientInformation,
738-
refreshToken: tokens.refresh_token,
739-
resource,
740-
addClientAuthentication: provider.addClientAuthentication,
741-
fetchFn
742-
});
774+
const refreshToken = tokens.refresh_token;
775+
const refreshAuthorizationKey = buildRefreshAuthorizationKey(authorizationServerUrl, refreshToken, resource);
776+
return await runWithRefreshAuthorizationLock(provider, refreshAuthorizationKey, async () => {
777+
const newTokens = await refreshAuthorization(authorizationServerUrl, {
778+
metadata,
779+
clientInformation,
780+
refreshToken,
781+
resource,
782+
addClientAuthentication: provider.addClientAuthentication,
783+
fetchFn
784+
});
743785

744-
await provider.saveTokens(newTokens);
745-
return 'AUTHORIZED';
786+
await provider.saveTokens(newTokens);
787+
return 'AUTHORIZED';
788+
});
746789
} catch (error) {
747790
// If this is a ServerError, or an unknown type, log it out and try to continue. Otherwise, escalate so we can fix things and retry.
748791
if (!(error instanceof OAuthError) || error.code === OAuthErrorCode.ServerError) {

packages/client/test/client/auth.test.ts

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2492,6 +2492,99 @@ describe('OAuth Authorization', () => {
24922492
expect(body.get('refresh_token')).toBe('refresh123');
24932493
});
24942494

2495+
it('deduplicates concurrent refreshes for the same provider and resource', async () => {
2496+
let releaseTokensReaders: (() => void) | undefined;
2497+
const tokensReady = new Promise<void>(resolve => {
2498+
releaseTokensReaders = resolve;
2499+
});
2500+
let tokensReaderCount = 0;
2501+
let resolveRefreshResponse: ((value: { ok: true; status: 200; json: () => Promise<OAuthTokens> }) => void) | undefined;
2502+
const refreshResponse = new Promise<{ ok: true; status: 200; json: () => Promise<OAuthTokens> }>(resolve => {
2503+
resolveRefreshResponse = resolve;
2504+
});
2505+
let refreshRequestCount = 0;
2506+
2507+
mockFetch.mockImplementation(url => {
2508+
const urlString = url.toString();
2509+
2510+
if (urlString.includes('/.well-known/oauth-protected-resource')) {
2511+
return Promise.resolve({
2512+
ok: true,
2513+
status: 200,
2514+
json: async () => ({
2515+
resource: 'https://api.example.com/mcp-server',
2516+
authorization_servers: ['https://auth.example.com']
2517+
})
2518+
});
2519+
}
2520+
2521+
if (urlString.includes('/.well-known/oauth-authorization-server')) {
2522+
return Promise.resolve({
2523+
ok: true,
2524+
status: 200,
2525+
json: async () => ({
2526+
issuer: 'https://auth.example.com',
2527+
authorization_endpoint: 'https://auth.example.com/authorize',
2528+
token_endpoint: 'https://auth.example.com/token',
2529+
response_types_supported: ['code'],
2530+
code_challenge_methods_supported: ['S256']
2531+
})
2532+
});
2533+
}
2534+
2535+
if (urlString.includes('/token')) {
2536+
refreshRequestCount++;
2537+
if (refreshRequestCount > 1) {
2538+
throw new Error('duplicate refresh request');
2539+
}
2540+
return refreshResponse;
2541+
}
2542+
2543+
return Promise.resolve({ ok: false, status: 404 });
2544+
});
2545+
2546+
(mockProvider.clientInformation as Mock).mockResolvedValue({
2547+
client_id: 'test-client',
2548+
client_secret: 'test-secret'
2549+
});
2550+
(mockProvider.tokens as Mock).mockImplementation(async () => {
2551+
tokensReaderCount++;
2552+
if (tokensReaderCount === 2) {
2553+
releaseTokensReaders?.();
2554+
}
2555+
await tokensReady;
2556+
return {
2557+
access_token: 'old-access',
2558+
refresh_token: 'refresh123'
2559+
};
2560+
});
2561+
(mockProvider.saveTokens as Mock).mockResolvedValue(undefined);
2562+
2563+
const authResults = Promise.all([
2564+
auth(mockProvider, { serverUrl: 'https://api.example.com/mcp-server' }),
2565+
auth(mockProvider, { serverUrl: 'https://api.example.com/mcp-server' })
2566+
]);
2567+
2568+
await vi.waitFor(() => {
2569+
expect(refreshRequestCount).toBe(1);
2570+
});
2571+
2572+
resolveRefreshResponse?.({
2573+
ok: true,
2574+
status: 200,
2575+
json: async () => ({
2576+
access_token: 'new-access123',
2577+
refresh_token: 'new-refresh456',
2578+
token_type: 'Bearer',
2579+
expires_in: 3600
2580+
})
2581+
});
2582+
2583+
await expect(authResults).resolves.toEqual(['AUTHORIZED', 'AUTHORIZED']);
2584+
expect(refreshRequestCount).toBe(1);
2585+
expect(mockProvider.saveTokens).toHaveBeenCalledTimes(1);
2586+
});
2587+
24952588
it('skips default PRM resource validation when custom validateResourceURL is provided', async () => {
24962589
const mockValidateResourceURL = vi.fn().mockResolvedValue(undefined);
24972590
const providerWithCustomValidation = {

0 commit comments

Comments
 (0)