Skip to content

Commit 1bb5ff6

Browse files
feat(mcp): expose auth and httpx_client_factory in SSE/StreamableHttp params (#2713)
1 parent 44bbcfe commit 1bb5ff6

File tree

2 files changed

+214
-23
lines changed

2 files changed

+214
-23
lines changed

src/agents/mcp/server.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,17 @@ class MCPServerSseParams(TypedDict):
919919
sse_read_timeout: NotRequired[float]
920920
"""The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
921921

922+
auth: NotRequired[httpx.Auth | None]
923+
"""Optional httpx authentication handler (e.g. ``httpx.BasicAuth``, a custom
924+
``httpx.Auth`` subclass for OAuth token refresh, etc.). When provided, it is
925+
passed directly to the underlying ``httpx.AsyncClient`` used by the SSE transport.
926+
"""
927+
928+
httpx_client_factory: NotRequired[HttpClientFactory]
929+
"""Custom HTTP client factory for configuring httpx.AsyncClient behavior (e.g.
930+
to set custom SSL certificates, proxies, or other transport options).
931+
"""
932+
922933

923934
class MCPServerSse(_MCPServerWithClientSession):
924935
"""MCP server implementation that uses the HTTP with SSE transport. See the [spec]
@@ -1000,12 +1011,17 @@ def create_streams(
10001011
self,
10011012
) -> AbstractAsyncContextManager[MCPStreamTransport]:
10021013
"""Create the streams for the server."""
1003-
return sse_client(
1004-
url=self.params["url"],
1005-
headers=self.params.get("headers", None),
1006-
timeout=self.params.get("timeout", 5),
1007-
sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
1008-
)
1014+
kwargs: dict[str, Any] = {
1015+
"url": self.params["url"],
1016+
"headers": self.params.get("headers", None),
1017+
"timeout": self.params.get("timeout", 5),
1018+
"sse_read_timeout": self.params.get("sse_read_timeout", 60 * 5),
1019+
}
1020+
if "auth" in self.params:
1021+
kwargs["auth"] = self.params["auth"]
1022+
if "httpx_client_factory" in self.params:
1023+
kwargs["httpx_client_factory"] = self.params["httpx_client_factory"]
1024+
return sse_client(**kwargs)
10091025

10101026
@property
10111027
def name(self) -> str:
@@ -1034,6 +1050,13 @@ class MCPServerStreamableHttpParams(TypedDict):
10341050
httpx_client_factory: NotRequired[HttpClientFactory]
10351051
"""Custom HTTP client factory for configuring httpx.AsyncClient behavior."""
10361052

1053+
auth: NotRequired[httpx.Auth | None]
1054+
"""Optional httpx authentication handler (e.g. ``httpx.BasicAuth``, a custom
1055+
``httpx.Auth`` subclass for OAuth token refresh, etc.). When provided, it is
1056+
passed directly to the underlying ``httpx.AsyncClient`` used by the Streamable HTTP
1057+
transport.
1058+
"""
1059+
10371060

10381061
class MCPServerStreamableHttp(_MCPServerWithClientSession):
10391062
"""MCP server implementation that uses the Streamable HTTP transport. See the [spec]
@@ -1117,24 +1140,18 @@ def create_streams(
11171140
self,
11181141
) -> AbstractAsyncContextManager[MCPStreamTransport]:
11191142
"""Create the streams for the server."""
1120-
# Only pass httpx_client_factory if it's provided
1143+
kwargs: dict[str, Any] = {
1144+
"url": self.params["url"],
1145+
"headers": self.params.get("headers", None),
1146+
"timeout": self.params.get("timeout", 5),
1147+
"sse_read_timeout": self.params.get("sse_read_timeout", 60 * 5),
1148+
"terminate_on_close": self.params.get("terminate_on_close", True),
1149+
}
11211150
if "httpx_client_factory" in self.params:
1122-
return streamablehttp_client(
1123-
url=self.params["url"],
1124-
headers=self.params.get("headers", None),
1125-
timeout=self.params.get("timeout", 5),
1126-
sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
1127-
terminate_on_close=self.params.get("terminate_on_close", True),
1128-
httpx_client_factory=self.params["httpx_client_factory"],
1129-
)
1130-
else:
1131-
return streamablehttp_client(
1132-
url=self.params["url"],
1133-
headers=self.params.get("headers", None),
1134-
timeout=self.params.get("timeout", 5),
1135-
sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
1136-
terminate_on_close=self.params.get("terminate_on_close", True),
1137-
)
1151+
kwargs["httpx_client_factory"] = self.params["httpx_client_factory"]
1152+
if "auth" in self.params:
1153+
kwargs["auth"] = self.params["auth"]
1154+
return streamablehttp_client(**kwargs)
11381155

11391156
@asynccontextmanager
11401157
async def _isolated_client_session(self):

tests/mcp/test_mcp_auth_params.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""Tests for auth and httpx_client_factory params on MCPServerSse and MCPServerStreamableHttp."""
2+
3+
from __future__ import annotations
4+
5+
from unittest.mock import MagicMock, patch
6+
7+
import httpx
8+
import pytest
9+
10+
from agents.mcp import MCPServerSse, MCPServerStreamableHttp
11+
12+
13+
class TestMCPServerSseAuthAndFactory:
14+
"""Tests for auth and httpx_client_factory added to MCPServerSseParams."""
15+
16+
@pytest.mark.asyncio
17+
async def test_sse_default_no_auth_no_factory(self):
18+
"""SSE create_streams passes only the four base params when no extras are set."""
19+
with patch("agents.mcp.server.sse_client") as mock_client:
20+
mock_client.return_value = MagicMock()
21+
server = MCPServerSse(params={"url": "http://localhost:8000/sse"})
22+
server.create_streams()
23+
mock_client.assert_called_once_with(
24+
url="http://localhost:8000/sse",
25+
headers=None,
26+
timeout=5,
27+
sse_read_timeout=300,
28+
)
29+
30+
@pytest.mark.asyncio
31+
async def test_sse_with_auth(self):
32+
"""SSE create_streams forwards the auth parameter when provided."""
33+
auth = httpx.BasicAuth(username="user", password="pass")
34+
with patch("agents.mcp.server.sse_client") as mock_client:
35+
mock_client.return_value = MagicMock()
36+
server = MCPServerSse(params={"url": "http://localhost:8000/sse", "auth": auth})
37+
server.create_streams()
38+
mock_client.assert_called_once_with(
39+
url="http://localhost:8000/sse",
40+
headers=None,
41+
timeout=5,
42+
sse_read_timeout=300,
43+
auth=auth,
44+
)
45+
46+
@pytest.mark.asyncio
47+
async def test_sse_with_httpx_client_factory(self):
48+
"""SSE create_streams forwards a custom httpx_client_factory when provided."""
49+
50+
def custom_factory(
51+
headers: dict[str, str] | None = None,
52+
timeout: httpx.Timeout | None = None,
53+
auth: httpx.Auth | None = None,
54+
) -> httpx.AsyncClient:
55+
return httpx.AsyncClient(verify=False) # pragma: no cover
56+
57+
with patch("agents.mcp.server.sse_client") as mock_client:
58+
mock_client.return_value = MagicMock()
59+
server = MCPServerSse(
60+
params={
61+
"url": "http://localhost:8000/sse",
62+
"httpx_client_factory": custom_factory,
63+
}
64+
)
65+
server.create_streams()
66+
mock_client.assert_called_once_with(
67+
url="http://localhost:8000/sse",
68+
headers=None,
69+
timeout=5,
70+
sse_read_timeout=300,
71+
httpx_client_factory=custom_factory,
72+
)
73+
74+
@pytest.mark.asyncio
75+
async def test_sse_with_auth_and_factory(self):
76+
"""SSE create_streams forwards both auth and httpx_client_factory together."""
77+
auth = httpx.BasicAuth(username="user", password="pass")
78+
79+
def custom_factory(
80+
headers: dict[str, str] | None = None,
81+
timeout: httpx.Timeout | None = None,
82+
auth: httpx.Auth | None = None,
83+
) -> httpx.AsyncClient:
84+
return httpx.AsyncClient(verify=False) # pragma: no cover
85+
86+
with patch("agents.mcp.server.sse_client") as mock_client:
87+
mock_client.return_value = MagicMock()
88+
server = MCPServerSse(
89+
params={
90+
"url": "http://localhost:8000/sse",
91+
"headers": {"X-Token": "abc"},
92+
"auth": auth,
93+
"httpx_client_factory": custom_factory,
94+
}
95+
)
96+
server.create_streams()
97+
mock_client.assert_called_once_with(
98+
url="http://localhost:8000/sse",
99+
headers={"X-Token": "abc"},
100+
timeout=5,
101+
sse_read_timeout=300,
102+
auth=auth,
103+
httpx_client_factory=custom_factory,
104+
)
105+
106+
107+
class TestMCPServerStreamableHttpAuth:
108+
"""Tests for the auth parameter added to MCPServerStreamableHttpParams."""
109+
110+
@pytest.mark.asyncio
111+
async def test_streamable_http_default_no_auth(self):
112+
"""StreamableHttp create_streams omits auth when not provided."""
113+
with patch("agents.mcp.server.streamablehttp_client") as mock_client:
114+
mock_client.return_value = MagicMock()
115+
server = MCPServerStreamableHttp(params={"url": "http://localhost:8000/mcp"})
116+
server.create_streams()
117+
mock_client.assert_called_once_with(
118+
url="http://localhost:8000/mcp",
119+
headers=None,
120+
timeout=5,
121+
sse_read_timeout=300,
122+
terminate_on_close=True,
123+
)
124+
125+
@pytest.mark.asyncio
126+
async def test_streamable_http_with_auth(self):
127+
"""StreamableHttp create_streams forwards the auth parameter when provided."""
128+
auth = httpx.BasicAuth(username="user", password="pass")
129+
with patch("agents.mcp.server.streamablehttp_client") as mock_client:
130+
mock_client.return_value = MagicMock()
131+
server = MCPServerStreamableHttp(
132+
params={"url": "http://localhost:8000/mcp", "auth": auth}
133+
)
134+
server.create_streams()
135+
mock_client.assert_called_once_with(
136+
url="http://localhost:8000/mcp",
137+
headers=None,
138+
timeout=5,
139+
sse_read_timeout=300,
140+
terminate_on_close=True,
141+
auth=auth,
142+
)
143+
144+
@pytest.mark.asyncio
145+
async def test_streamable_http_with_auth_and_factory(self):
146+
"""StreamableHttp create_streams forwards both auth and httpx_client_factory."""
147+
auth = httpx.BasicAuth(username="user", password="pass")
148+
149+
def custom_factory(
150+
headers: dict[str, str] | None = None,
151+
timeout: httpx.Timeout | None = None,
152+
auth: httpx.Auth | None = None,
153+
) -> httpx.AsyncClient:
154+
return httpx.AsyncClient(verify=False) # pragma: no cover
155+
156+
with patch("agents.mcp.server.streamablehttp_client") as mock_client:
157+
mock_client.return_value = MagicMock()
158+
server = MCPServerStreamableHttp(
159+
params={
160+
"url": "http://localhost:8000/mcp",
161+
"auth": auth,
162+
"httpx_client_factory": custom_factory,
163+
}
164+
)
165+
server.create_streams()
166+
mock_client.assert_called_once_with(
167+
url="http://localhost:8000/mcp",
168+
headers=None,
169+
timeout=5,
170+
sse_read_timeout=300,
171+
terminate_on_close=True,
172+
auth=auth,
173+
httpx_client_factory=custom_factory,
174+
)

0 commit comments

Comments
 (0)