Skip to content

Commit 6758180

Browse files
committed
fix(oauth): preserve existing refresh_token when refresh response omits it (#2270)
1 parent 62eb08e commit 6758180

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,18 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: # p
458458
content = await response.aread()
459459
token_response = OAuthToken.model_validate_json(content)
460460

461+
# Per RFC 6749 Section 6, the authorization server MAY issue a new
462+
# refresh token. If the response omits one, preserve the existing
463+
# refresh token so subsequent refresh attempts remain possible.
464+
if (
465+
not token_response.refresh_token
466+
and self.context.current_tokens
467+
and self.context.current_tokens.refresh_token
468+
):
469+
token_response = token_response.model_copy(
470+
update={"refresh_token": self.context.current_tokens.refresh_token}
471+
)
472+
461473
self.context.current_tokens = token_response
462474
self.context.update_token_expiry(token_response)
463475
await self.context.storage.set_tokens(token_response)

tests/client/test_auth.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,75 @@ async def test_basic_auth_refresh_token(self, oauth_provider: OAuthClientProvide
711711
content = request.content.decode()
712712
assert "client_secret=" not in content
713713

714+
@pytest.mark.anyio
715+
async def test_handle_refresh_response_preserves_existing_refresh_token(
716+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
717+
):
718+
"""Test that the existing refresh_token is preserved when the server omits it.
719+
720+
Per RFC 6749 Section 6, the authorization server MAY issue a new refresh
721+
token in the refresh response. If it doesn't, the client should continue
722+
using the existing one.
723+
"""
724+
oauth_provider.context.current_tokens = valid_tokens
725+
726+
# Server response without refresh_token
727+
refresh_response = httpx.Response(
728+
200,
729+
content=b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600}',
730+
request=httpx.Request("POST", "https://auth.example.com/token"),
731+
)
732+
733+
result = await oauth_provider._handle_refresh_response(refresh_response)
734+
735+
assert result is True
736+
assert oauth_provider.context.current_tokens is not None
737+
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
738+
# Old refresh_token should be preserved
739+
assert oauth_provider.context.current_tokens.refresh_token == "test_refresh_token"
740+
741+
@pytest.mark.anyio
742+
async def test_handle_refresh_response_uses_new_refresh_token(
743+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
744+
):
745+
"""Test that a new refresh_token from the server replaces the old one."""
746+
oauth_provider.context.current_tokens = valid_tokens
747+
748+
# Server response with a new refresh_token (token rotation)
749+
refresh_response = httpx.Response(
750+
200,
751+
content=(
752+
b'{"access_token": "new_access_token", "token_type": "Bearer",'
753+
b' "expires_in": 3600, "refresh_token": "rotated_refresh_token"}'
754+
),
755+
request=httpx.Request("POST", "https://auth.example.com/token"),
756+
)
757+
758+
result = await oauth_provider._handle_refresh_response(refresh_response)
759+
760+
assert result is True
761+
assert oauth_provider.context.current_tokens is not None
762+
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
763+
assert oauth_provider.context.current_tokens.refresh_token == "rotated_refresh_token"
764+
765+
@pytest.mark.anyio
766+
async def test_handle_refresh_response_no_prior_tokens(self, oauth_provider: OAuthClientProvider):
767+
"""Test refresh response when there are no prior tokens stored."""
768+
oauth_provider.context.current_tokens = None
769+
770+
refresh_response = httpx.Response(
771+
200,
772+
content=b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600}',
773+
request=httpx.Request("POST", "https://auth.example.com/token"),
774+
)
775+
776+
result = await oauth_provider._handle_refresh_response(refresh_response)
777+
778+
assert result is True
779+
assert oauth_provider.context.current_tokens is not None
780+
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
781+
assert oauth_provider.context.current_tokens.refresh_token is None
782+
714783
@pytest.mark.anyio
715784
async def test_none_auth_method(self, oauth_provider: OAuthClientProvider):
716785
"""Test 'none' authentication method (public client)."""

0 commit comments

Comments
 (0)