Skip to content

Commit ca0cc75

Browse files
committed
feat(tools): expose httpx_client_factory on RestApiTool and OpenAPIToolset
Mirror the pattern merged for MCP in #2997 (StreamableHTTPConnectionParams. httpx_client_factory) on the OpenAPI tool surface. Adds an optional httpx_client_factory parameter to: - RestApiTool.__init__ - RestApiTool.from_parsed_operation - OpenAPIToolset.__init__ OpenAPIToolset forwards the factory to every generated RestApiTool the same way ssl_verify and header_provider are already forwarded. When provided, the factory's client is used to issue each API call; when None (default), the existing httpx.AsyncClient(verify=..., timeout=None) construction is preserved exactly. This unlocks httpx.AsyncClient knobs that the narrower ssl_verify parameter can't reach: proxies, HTTP/2, custom transports (e.g. request signing), and shared connection pools. Closes #5681
1 parent 4309159 commit ca0cc75

4 files changed

Lines changed: 182 additions & 6 deletions

File tree

src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from ...base_toolset import BaseToolset
3737
from ...base_toolset import ToolPredicate
3838
from .openapi_spec_parser import OpenApiSpecParser
39+
from .rest_api_tool import HttpxClientFactory
3940
from .rest_api_tool import RestApiTool
4041

4142
logger = logging.getLogger("google_adk." + __name__)
@@ -77,6 +78,7 @@ def __init__(
7778
header_provider: Optional[
7879
Callable[[ReadonlyContext], Dict[str, str]]
7980
] = None,
81+
httpx_client_factory: Optional[HttpxClientFactory] = None,
8082
preserve_property_names: bool = False,
8183
):
8284
"""Initializes the OpenAPIToolset.
@@ -130,6 +132,14 @@ def __init__(
130132
an argument, allowing dynamic header generation based on the current
131133
context. Useful for adding custom headers like correlation IDs,
132134
authentication tokens, or other request metadata.
135+
httpx_client_factory: Optional zero-argument callable returning an
136+
``httpx.AsyncClient`` to use for every generated tool's API calls.
137+
When provided, it takes precedence over the per-tool default client
138+
construction and unlocks ``httpx.AsyncClient`` options that
139+
``ssl_verify`` can't reach (proxies, HTTP/2, custom transports such as
140+
request signing, shared connection pools). Defaults to ``None``, which
141+
preserves today's behaviour. Mirrors the pattern exposed for MCP by
142+
``StreamableHTTPConnectionParams.httpx_client_factory``.
133143
preserve_property_names: If True, preserve the original property names
134144
from the OpenAPI spec instead of converting them to snake_case. This
135145
is useful when calling APIs that expect camelCase or other
@@ -155,6 +165,7 @@ def __init__(
155165
if not spec_dict:
156166
spec_dict = self._load_spec(spec_str, spec_str_type)
157167
self._ssl_verify = ssl_verify
168+
self._httpx_client_factory = httpx_client_factory
158169
self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
159170
if auth_scheme or auth_credential:
160171
self._configure_auth_all(auth_scheme, auth_credential)
@@ -237,6 +248,7 @@ def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]:
237248
o,
238249
ssl_verify=self._ssl_verify,
239250
header_provider=self._header_provider,
251+
httpx_client_factory=self._httpx_client_factory,
240252
)
241253
logger.info("Parsed tool: %s", tool.name)
242254
tools.append(tool)

src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,17 @@ def snake_to_lower_camel(snake_case_string: str):
7575

7676
AuthPreparationState = Literal["pending", "done"]
7777

78+
HttpxClientFactory = Callable[..., httpx.AsyncClient]
79+
"""Type alias for a factory returning an ``httpx.AsyncClient``.
80+
81+
When supplied to ``RestApiTool`` or ``OpenAPIToolset``, the factory is invoked
82+
once per API call and its returned client is used (as an async context
83+
manager) to issue the request, in place of the default
84+
``httpx.AsyncClient(verify=..., timeout=None)``. This unlocks knobs that the
85+
narrower ``ssl_verify`` parameter can't reach: proxies, HTTP/2, custom
86+
transports (e.g. request-signing), shared connection pools, and so on.
87+
"""
88+
7889

7990
class RestApiTool(BaseTool):
8091
"""A generic tool that interacts with a REST API.
@@ -103,6 +114,7 @@ def __init__(
103114
header_provider: Optional[
104115
Callable[[ReadonlyContext], Dict[str, str]]
105116
] = None,
117+
httpx_client_factory: Optional[HttpxClientFactory] = None,
106118
*,
107119
credential_key: Optional[str] = None,
108120
):
@@ -142,6 +154,15 @@ def __init__(
142154
an argument, allowing dynamic header generation based on the current
143155
context. Useful for adding custom headers like correlation IDs,
144156
authentication tokens, or other request metadata.
157+
httpx_client_factory: Optional zero-argument callable returning an
158+
``httpx.AsyncClient``. When provided, the returned client is used to
159+
issue the request, allowing callers to configure proxies, HTTP/2,
160+
custom transports (e.g. request signing), shared connection pools,
161+
or any other ``httpx.AsyncClient`` option that ``ssl_verify`` can't
162+
reach. When ``None`` (default), behaviour is unchanged: a fresh
163+
``httpx.AsyncClient(verify=..., timeout=None)`` is created per
164+
request. Mirrors the pattern exposed for MCP by
165+
``StreamableHTTPConnectionParams.httpx_client_factory``.
145166
credential_key: Optional stable key used for interactive auth and
146167
credential caching.
147168
"""
@@ -169,6 +190,7 @@ def __init__(
169190
self._default_headers: Dict[str, str] = {}
170191
self._ssl_verify = ssl_verify
171192
self._header_provider = header_provider
193+
self._httpx_client_factory = httpx_client_factory
172194
self._logger = logger
173195
if should_parse_operation:
174196
self._operation_parser = OperationParser(self.operation)
@@ -181,6 +203,7 @@ def from_parsed_operation(
181203
header_provider: Optional[
182204
Callable[[ReadonlyContext], Dict[str, str]]
183205
] = None,
206+
httpx_client_factory: Optional[HttpxClientFactory] = None,
184207
) -> "RestApiTool":
185208
"""Initializes the RestApiTool from a ParsedOperation object.
186209
@@ -192,6 +215,9 @@ def from_parsed_operation(
192215
an argument, allowing dynamic header generation based on the current
193216
context. Useful for adding custom headers like correlation IDs,
194217
authentication tokens, or other request metadata.
218+
httpx_client_factory: Optional zero-argument callable returning an
219+
``httpx.AsyncClient`` to be used for the API call. See
220+
``RestApiTool.__init__`` for details.
195221
196222
Returns:
197223
A RestApiTool object.
@@ -212,6 +238,7 @@ def from_parsed_operation(
212238
auth_credential=parsed.auth_credential,
213239
ssl_verify=ssl_verify,
214240
header_provider=header_provider,
241+
httpx_client_factory=httpx_client_factory,
215242
)
216243
generated._operation_parser = operation_parser
217244
return generated
@@ -520,7 +547,9 @@ async def call(
520547
if provider_headers:
521548
request_params.setdefault("headers", {}).update(provider_headers)
522549

523-
response = await _request(**request_params)
550+
response = await _request(
551+
httpx_client_factory=self._httpx_client_factory, **request_params
552+
)
524553

525554
# Log the API response
526555
self._logger.debug(
@@ -569,9 +598,14 @@ def __repr__(self):
569598
)
570599

571600

572-
async def _request(**request_params) -> httpx.Response:
573-
async with httpx.AsyncClient(
574-
verify=request_params.pop("verify", True),
575-
timeout=None,
576-
) as client:
601+
async def _request(
602+
*,
603+
httpx_client_factory: Optional[HttpxClientFactory] = None,
604+
**request_params,
605+
) -> httpx.Response:
606+
verify = request_params.pop("verify", True)
607+
if httpx_client_factory is not None:
608+
async with httpx_client_factory() as client:
609+
return await client.request(**request_params)
610+
async with httpx.AsyncClient(verify=verify, timeout=None) as client:
577611
return await client.request(**request_params)

tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_toolset.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,29 @@ def test_openapi_toolset_verify_on_init(
153153
assert all(tool._ssl_verify == verify_value for tool in toolset._tools)
154154

155155

156+
def test_openapi_toolset_httpx_client_factory_on_init(
157+
openapi_spec: Dict[str, Any],
158+
):
159+
"""The httpx_client_factory is forwarded to every generated tool."""
160+
custom_factory = lambda: None # noqa: E731 - placeholder, never invoked here
161+
toolset = OpenAPIToolset(
162+
spec_dict=openapi_spec, httpx_client_factory=custom_factory
163+
)
164+
assert toolset._httpx_client_factory is custom_factory
165+
assert all(
166+
tool._httpx_client_factory is custom_factory for tool in toolset._tools
167+
)
168+
169+
170+
def test_openapi_toolset_httpx_client_factory_none_by_default(
171+
openapi_spec: Dict[str, Any],
172+
):
173+
"""httpx_client_factory is None on the toolset and each tool by default."""
174+
toolset = OpenAPIToolset(spec_dict=openapi_spec)
175+
assert toolset._httpx_client_factory is None
176+
assert all(tool._httpx_client_factory is None for tool in toolset._tools)
177+
178+
156179
def test_openapi_toolset_configure_verify_all(openapi_spec: Dict[str, Any]):
157180
"""Test configure_verify_all method."""
158181
toolset = OpenAPIToolset(spec_dict=openapi_spec)

tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,6 +1339,113 @@ async def test_call_without_header_provider(
13391339

13401340
assert result == {"result": "success"}
13411341

1342+
def test_init_httpx_client_factory_none_by_default(
1343+
self,
1344+
sample_endpoint,
1345+
sample_operation,
1346+
):
1347+
"""httpx_client_factory is None by default."""
1348+
tool = RestApiTool(
1349+
name="test_tool",
1350+
description="Test Tool",
1351+
endpoint=sample_endpoint,
1352+
operation=sample_operation,
1353+
)
1354+
assert tool._httpx_client_factory is None
1355+
1356+
def test_init_with_httpx_client_factory(
1357+
self,
1358+
sample_endpoint,
1359+
sample_operation,
1360+
):
1361+
"""A user-supplied httpx_client_factory is stored on the tool."""
1362+
custom_factory = MagicMock()
1363+
tool = RestApiTool(
1364+
name="test_tool",
1365+
description="Test Tool",
1366+
endpoint=sample_endpoint,
1367+
operation=sample_operation,
1368+
httpx_client_factory=custom_factory,
1369+
)
1370+
assert tool._httpx_client_factory is custom_factory
1371+
1372+
@pytest.mark.asyncio
1373+
async def test_call_uses_custom_httpx_client_factory(
1374+
self,
1375+
mock_tool_context,
1376+
sample_endpoint,
1377+
sample_operation,
1378+
sample_auth_scheme,
1379+
sample_auth_credential,
1380+
):
1381+
"""When a factory is provided, its client is used to issue the request."""
1382+
mock_response = mock.create_autospec(requests.Response, instance=True)
1383+
mock_response.json.return_value = {"result": "success"}
1384+
mock_response.configure_mock(status_code=200)
1385+
1386+
mock_client = mock.create_autospec(
1387+
httpx.AsyncClient, instance=True, spec_set=True
1388+
)
1389+
mock_client.request = AsyncMock(return_value=mock_response)
1390+
# Make the mock client work as an async context manager.
1391+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
1392+
mock_client.__aexit__ = AsyncMock(return_value=None)
1393+
1394+
custom_factory = MagicMock(return_value=mock_client)
1395+
1396+
tool = RestApiTool(
1397+
name="test_tool",
1398+
description="Test Tool",
1399+
endpoint=sample_endpoint,
1400+
operation=sample_operation,
1401+
auth_scheme=sample_auth_scheme,
1402+
auth_credential=sample_auth_credential,
1403+
httpx_client_factory=custom_factory,
1404+
)
1405+
1406+
with patch.object(httpx, "AsyncClient", autospec=True) as mock_default:
1407+
result = await tool.call(args={}, tool_context=mock_tool_context)
1408+
1409+
# Factory must be invoked once and the default client must not be built.
1410+
custom_factory.assert_called_once_with()
1411+
mock_default.assert_not_called()
1412+
mock_client.request.assert_awaited_once()
1413+
assert result == {"result": "success"}
1414+
1415+
@pytest.mark.asyncio
1416+
async def test_call_without_httpx_client_factory_uses_default_client(
1417+
self,
1418+
mock_tool_context,
1419+
sample_endpoint,
1420+
sample_operation,
1421+
sample_auth_scheme,
1422+
sample_auth_credential,
1423+
):
1424+
"""When no factory is provided, the default httpx.AsyncClient is used."""
1425+
mock_response = mock.create_autospec(requests.Response, instance=True)
1426+
mock_response.json.return_value = {"result": "success"}
1427+
mock_response.configure_mock(status_code=200)
1428+
1429+
mock_client = mock.create_autospec(
1430+
httpx.AsyncClient, instance=True, spec_set=True
1431+
)
1432+
mock_client.request = AsyncMock(return_value=mock_response)
1433+
1434+
tool = RestApiTool(
1435+
name="test_tool",
1436+
description="Test Tool",
1437+
endpoint=sample_endpoint,
1438+
operation=sample_operation,
1439+
auth_scheme=sample_auth_scheme,
1440+
auth_credential=sample_auth_credential,
1441+
)
1442+
1443+
with patch.object(
1444+
httpx, "AsyncClient", return_value=mock_client, autospec=True
1445+
) as mock_async_client:
1446+
await tool.call(args={}, tool_context=mock_tool_context)
1447+
assert mock_async_client.called
1448+
13421449
def test_prepare_request_params_extracts_embedded_query_params(
13431450
self, sample_auth_credential, sample_auth_scheme
13441451
):

0 commit comments

Comments
 (0)