Skip to content

Commit 11c135a

Browse files
committed
fix(client): preserve existing query params on OAuth authorization_endpoint
Closes #2776 The authorization code grant built the redirect URL with `f"{auth_endpoint}?{urlencode(auth_params)}"`, which produces an invalid URL when the server-advertised authorization_endpoint already carries a query string. For example Salesforce advertises `.../services/oauth2/authorize?prompt=select_account`, yielding `...authorize?prompt=select_account?response_type=code&...` (two `?` separators), so the client navigates to a malformed URL and the server rejects the request. Fix: parse the endpoint, merge its existing query params with the flow-generated auth_params (flow params win on conflict), and re-encode into a single well-formed query string. None-valued params are dropped rather than serialized as the literal "None". Tests: add TestAuthorizationEndpointWithQuery covering the helper (no/with/conflicting existing query) plus an end-to-end _perform_authorization_code_grant assertion that the captured redirect URL preserves the server param and stays well-formed. 101 passed.
1 parent 19fe9fa commit 11c135a

2 files changed

Lines changed: 103 additions & 3 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
import secrets
1010
import string
1111
import time
12-
from collections.abc import AsyncGenerator, Awaitable, Callable
12+
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
1313
from dataclasses import dataclass, field
1414
from typing import Any, Protocol
15-
from urllib.parse import quote, urlencode, urljoin, urlparse
15+
from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlunparse
1616

1717
import anyio
1818
import httpx
@@ -53,6 +53,22 @@
5353
logger = logging.getLogger(__name__)
5454

5555

56+
def _build_authorization_url(auth_endpoint: str, auth_params: Mapping[str, str | None]) -> str:
57+
"""Build an authorization URL, preserving any query params already on the endpoint.
58+
59+
Servers may advertise an ``authorization_endpoint`` that already carries query
60+
parameters (e.g. ``https://example.com/authorize?prompt=select_account``).
61+
Naively appending ``?<params>`` would produce an invalid URL with two ``?``
62+
separators, so the existing query is parsed and merged with ``auth_params``.
63+
Flow-generated params take precedence on key conflicts; ``None`` values are
64+
dropped rather than serialized as the literal string ``"None"``.
65+
"""
66+
parsed = urlparse(auth_endpoint)
67+
merged_params = dict(parse_qsl(parsed.query, keep_blank_values=True))
68+
merged_params.update({key: value for key, value in auth_params.items() if value is not None})
69+
return urlunparse(parsed._replace(query=urlencode(merged_params)))
70+
71+
5672
class PKCEParameters(BaseModel):
5773
"""PKCE (Proof Key for Code Exchange) parameters."""
5874

@@ -353,7 +369,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
353369
if "offline_access" in self.context.client_metadata.scope.split():
354370
auth_params["prompt"] = "consent"
355371

356-
authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
372+
authorization_url = _build_authorization_url(auth_endpoint, auth_params)
357373
await self.context.redirect_handler(authorization_url)
358374

359375
# Wait for callback

tests/client/test_auth.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from mcp.client.auth import OAuthClientProvider, PKCEParameters
1414
from mcp.client.auth.exceptions import OAuthFlowError
15+
from mcp.client.auth.oauth2 import _build_authorization_url
1516
from mcp.client.auth.utils import (
1617
build_oauth_authorization_server_metadata_discovery_urls,
1718
build_protected_resource_metadata_discovery_urls,
@@ -2618,3 +2619,86 @@ async def callback_handler() -> tuple[str, str | None]:
26182619
await auth_flow.asend(final_response)
26192620
except StopAsyncIteration:
26202621
pass
2622+
2623+
2624+
class TestAuthorizationEndpointWithQuery:
2625+
"""Regression tests for #2776 - authorization_endpoint carrying query params."""
2626+
2627+
def test_build_authorization_url_no_existing_query(self):
2628+
url = _build_authorization_url(
2629+
"https://auth.example.com/authorize",
2630+
{"response_type": "code", "client_id": "abc"},
2631+
)
2632+
parsed = urlparse(url)
2633+
params = parse_qs(parsed.query)
2634+
assert parsed.path == "/authorize"
2635+
assert params["response_type"] == ["code"]
2636+
assert params["client_id"] == ["abc"]
2637+
# No malformed double "?" separator.
2638+
assert url.count("?") == 1
2639+
2640+
def test_build_authorization_url_preserves_existing_query(self):
2641+
# e.g. Salesforce advertises .../authorize?prompt=select_account
2642+
url = _build_authorization_url(
2643+
"https://test.salesforce.com/services/oauth2/authorize?prompt=select_account",
2644+
{"response_type": "code", "client_id": "abc"},
2645+
)
2646+
parsed = urlparse(url)
2647+
params = parse_qs(parsed.query)
2648+
assert parsed.path == "/services/oauth2/authorize"
2649+
# The server-provided param survives...
2650+
assert params["prompt"] == ["select_account"]
2651+
# ...alongside the flow-generated params.
2652+
assert params["response_type"] == ["code"]
2653+
assert params["client_id"] == ["abc"]
2654+
# Exactly one "?" - the old f-string produced "...?prompt=...?response_type=...".
2655+
assert url.count("?") == 1
2656+
2657+
def test_build_authorization_url_flow_params_win_on_conflict(self):
2658+
url = _build_authorization_url(
2659+
"https://auth.example.com/authorize?response_type=token",
2660+
{"response_type": "code"},
2661+
)
2662+
params = parse_qs(urlparse(url).query)
2663+
assert params["response_type"] == ["code"]
2664+
2665+
@pytest.mark.anyio
2666+
async def test_perform_authorization_preserves_endpoint_query(self, oauth_provider: OAuthClientProvider):
2667+
"""End-to-end: redirect URL stays valid when the endpoint has a query string."""
2668+
oauth_provider.context.oauth_metadata = OAuthMetadata(
2669+
issuer=AnyHttpUrl("https://test.salesforce.com"),
2670+
authorization_endpoint=AnyHttpUrl(
2671+
"https://test.salesforce.com/services/oauth2/authorize?prompt=select_account"
2672+
),
2673+
token_endpoint=AnyHttpUrl("https://test.salesforce.com/services/oauth2/token"),
2674+
)
2675+
oauth_provider.context.client_info = OAuthClientInformationFull(
2676+
client_id="test_client_id",
2677+
client_secret="test_client_secret",
2678+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2679+
)
2680+
2681+
captured_url: str | None = None
2682+
captured_state: str | None = None
2683+
2684+
async def capture_redirect(url: str) -> None:
2685+
nonlocal captured_url, captured_state
2686+
captured_url = url
2687+
captured_state = parse_qs(urlparse(url).query).get("state", [None])[0]
2688+
2689+
async def mock_callback() -> tuple[str, str | None]:
2690+
return "test_auth_code", captured_state
2691+
2692+
oauth_provider.context.redirect_handler = capture_redirect
2693+
oauth_provider.context.callback_handler = mock_callback
2694+
2695+
await oauth_provider._perform_authorization_code_grant()
2696+
2697+
assert captured_url is not None
2698+
parsed = urlparse(captured_url)
2699+
params = parse_qs(parsed.query)
2700+
assert parsed.path == "/services/oauth2/authorize"
2701+
assert params["prompt"] == ["select_account"]
2702+
assert params["response_type"] == ["code"]
2703+
assert params["client_id"] == ["test_client_id"]
2704+
assert captured_url.count("?") == 1

0 commit comments

Comments
 (0)