Skip to content

Commit 7eb9b3d

Browse files
darshil3011xuanyang15
authored andcommitted
feat(tools): expose httpx_client_factory on RestApiTool and OpenAPIToolset
Merge #5715 Close #5681 ORIGINAL_AUTHOR=Darshil Modi <45987056+darshil3011@users.noreply.github.com> GitOrigin-RevId: 850900c Change-Id: I1dbe5929b907d5044ecfb8882569db3aa9ed666e
1 parent dc6e293 commit 7eb9b3d

4 files changed

Lines changed: 187 additions & 6 deletions

File tree

src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from ...base_toolset import BaseToolset
3737
from ...base_toolset import ToolPredicate
3838
from .openapi_spec_parser import OpenApiSpecParser
39+
from .rest_api_tool import HttpxClientFactory
3940
from .rest_api_tool import RestApiTool
4041

4142
logger = logging.getLogger("google_adk." + __name__)
@@ -77,6 +78,7 @@ def __init__(
7778
header_provider: Optional[
7879
Callable[[ReadonlyContext], Dict[str, str]]
7980
] = None,
81+
httpx_client_factory: Optional[HttpxClientFactory] = None,
8082
preserve_property_names: bool = False,
8183
):
8284
"""Initializes the OpenAPIToolset.
@@ -130,6 +132,17 @@ def __init__(
130132
an argument, allowing dynamic header generation based on the current
131133
context. Useful for adding custom headers like correlation IDs,
132134
authentication tokens, or other request metadata.
135+
httpx_client_factory: Optional zero-argument callable returning an
136+
``httpx.AsyncClient`` to use for every generated tool's API calls. When
137+
provided, it takes precedence over the per-tool default client
138+
construction and unlocks ``httpx.AsyncClient`` options that
139+
``ssl_verify`` can't reach (proxies, HTTP/2, custom transports such as
140+
request signing). The returned client is used as an async context
141+
manager and closed after each request, so the factory must return a
142+
fresh client on every call. Defaults to ``None``, in which case each
143+
generated tool constructs its own ``httpx.AsyncClient`` per request.
144+
Mirrors the pattern exposed for MCP by
145+
``StreamableHTTPConnectionParams.httpx_client_factory``.
133146
preserve_property_names: If True, preserve the original property names
134147
from the OpenAPI spec instead of converting them to snake_case. This
135148
is useful when calling APIs that expect camelCase or other
@@ -155,6 +168,7 @@ def __init__(
155168
if not spec_dict:
156169
spec_dict = self._load_spec(spec_str, spec_str_type)
157170
self._ssl_verify = ssl_verify
171+
self._httpx_client_factory = httpx_client_factory
158172
self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
159173
if auth_scheme or auth_credential:
160174
self._configure_auth_all(auth_scheme, auth_credential)
@@ -237,6 +251,7 @@ def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]:
237251
o,
238252
ssl_verify=self._ssl_verify,
239253
header_provider=self._header_provider,
254+
httpx_client_factory=self._httpx_client_factory,
240255
)
241256
logger.info("Parsed tool: %s", tool.name)
242257
tools.append(tool)

src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,18 @@ def snake_to_lower_camel(snake_case_string: str):
7575

7676
AuthPreparationState = Literal["pending", "done"]
7777

78+
HttpxClientFactory = Callable[[], httpx.AsyncClient]
79+
"""Type alias for a zero-argument factory returning an ``httpx.AsyncClient``.
80+
81+
When supplied to ``RestApiTool`` or ``OpenAPIToolset``, the factory is invoked
82+
once per API call and its returned client is used as an async context manager
83+
to issue the request, in place of the default
84+
``httpx.AsyncClient(verify=..., timeout=None)``. Because the client is closed
85+
when the request completes, the factory must return a fresh client on every
86+
call. This unlocks knobs that the narrower ``ssl_verify`` parameter can't
87+
reach: proxies, HTTP/2, custom transports (e.g. request-signing), and so on.
88+
"""
89+
7890

7991
class RestApiTool(BaseTool):
8092
"""A generic tool that interacts with a REST API.
@@ -103,6 +115,7 @@ def __init__(
103115
header_provider: Optional[
104116
Callable[[ReadonlyContext], Dict[str, str]]
105117
] = None,
118+
httpx_client_factory: Optional[HttpxClientFactory] = None,
106119
*,
107120
credential_key: Optional[str] = None,
108121
):
@@ -142,6 +155,16 @@ def __init__(
142155
an argument, allowing dynamic header generation based on the current
143156
context. Useful for adding custom headers like correlation IDs,
144157
authentication tokens, or other request metadata.
158+
httpx_client_factory: Optional zero-argument callable returning an
159+
``httpx.AsyncClient``. When provided, the returned client is used as
160+
an async context manager to issue the request and is closed once the
161+
request completes, so the factory must return a fresh client on each
162+
call. This lets callers configure proxies, HTTP/2, custom transports
163+
(e.g. request signing), or any other ``httpx.AsyncClient`` option
164+
that ``ssl_verify`` can't reach. When ``None`` (default), a fresh
165+
``httpx.AsyncClient(verify=..., timeout=None)`` is created per
166+
request. Mirrors the pattern exposed for MCP by
167+
``StreamableHTTPConnectionParams.httpx_client_factory``.
145168
credential_key: Optional stable key used for interactive auth and
146169
credential caching.
147170
"""
@@ -169,6 +192,7 @@ def __init__(
169192
self._default_headers: Dict[str, str] = {}
170193
self._ssl_verify = ssl_verify
171194
self._header_provider = header_provider
195+
self._httpx_client_factory = httpx_client_factory
172196
self._logger = logger
173197
if should_parse_operation:
174198
self._operation_parser = OperationParser(self.operation)
@@ -181,6 +205,7 @@ def from_parsed_operation(
181205
header_provider: Optional[
182206
Callable[[ReadonlyContext], Dict[str, str]]
183207
] = None,
208+
httpx_client_factory: Optional[HttpxClientFactory] = None,
184209
) -> "RestApiTool":
185210
"""Initializes the RestApiTool from a ParsedOperation object.
186211
@@ -192,6 +217,9 @@ def from_parsed_operation(
192217
an argument, allowing dynamic header generation based on the current
193218
context. Useful for adding custom headers like correlation IDs,
194219
authentication tokens, or other request metadata.
220+
httpx_client_factory: Optional zero-argument callable returning an
221+
``httpx.AsyncClient`` to be used for the API call. See
222+
``RestApiTool.__init__`` for details.
195223
196224
Returns:
197225
A RestApiTool object.
@@ -212,6 +240,7 @@ def from_parsed_operation(
212240
auth_credential=parsed.auth_credential,
213241
ssl_verify=ssl_verify,
214242
header_provider=header_provider,
243+
httpx_client_factory=httpx_client_factory,
215244
)
216245
generated._operation_parser = operation_parser
217246
return generated
@@ -520,7 +549,9 @@ async def call(
520549
if provider_headers:
521550
request_params.setdefault("headers", {}).update(provider_headers)
522551

523-
response = await _request(**request_params)
552+
response = await _request(
553+
httpx_client_factory=self._httpx_client_factory, **request_params
554+
)
524555

525556
# Log the API response
526557
self._logger.debug(
@@ -575,9 +606,14 @@ def __repr__(self):
575606
)
576607

577608

578-
async def _request(**request_params) -> httpx.Response:
579-
async with httpx.AsyncClient(
580-
verify=request_params.pop("verify", True),
581-
timeout=None,
582-
) as client:
609+
async def _request(
610+
*,
611+
httpx_client_factory: Optional[HttpxClientFactory] = None,
612+
**request_params,
613+
) -> httpx.Response:
614+
verify = request_params.pop("verify", True)
615+
if httpx_client_factory is not None:
616+
async with httpx_client_factory() as client:
617+
return await client.request(**request_params)
618+
async with httpx.AsyncClient(verify=verify, timeout=None) as client:
583619
return await client.request(**request_params)

tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_toolset.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,29 @@ def test_openapi_toolset_verify_on_init(
153153
assert all(tool._ssl_verify == verify_value for tool in toolset._tools)
154154

155155

156+
def test_openapi_toolset_httpx_client_factory_on_init(
157+
openapi_spec: Dict[str, Any],
158+
):
159+
"""The httpx_client_factory is forwarded to every generated tool."""
160+
custom_factory = lambda: None # noqa: E731 - placeholder, never invoked here
161+
toolset = OpenAPIToolset(
162+
spec_dict=openapi_spec, httpx_client_factory=custom_factory
163+
)
164+
assert toolset._httpx_client_factory is custom_factory
165+
assert all(
166+
tool._httpx_client_factory is custom_factory for tool in toolset._tools
167+
)
168+
169+
170+
def test_openapi_toolset_httpx_client_factory_none_by_default(
171+
openapi_spec: Dict[str, Any],
172+
):
173+
"""httpx_client_factory is None on the toolset and each tool by default."""
174+
toolset = OpenAPIToolset(spec_dict=openapi_spec)
175+
assert toolset._httpx_client_factory is None
176+
assert all(tool._httpx_client_factory is None for tool in toolset._tools)
177+
178+
156179
def test_openapi_toolset_configure_verify_all(openapi_spec: Dict[str, Any]):
157180
"""Test configure_verify_all method."""
158181
toolset = OpenAPIToolset(spec_dict=openapi_spec)

tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,6 +1339,113 @@ async def test_call_without_header_provider(
13391339

13401340
assert result == {"result": "success"}
13411341

1342+
def test_init_httpx_client_factory_none_by_default(
1343+
self,
1344+
sample_endpoint,
1345+
sample_operation,
1346+
):
1347+
"""httpx_client_factory is None by default."""
1348+
tool = RestApiTool(
1349+
name="test_tool",
1350+
description="Test Tool",
1351+
endpoint=sample_endpoint,
1352+
operation=sample_operation,
1353+
)
1354+
assert tool._httpx_client_factory is None
1355+
1356+
def test_init_with_httpx_client_factory(
1357+
self,
1358+
sample_endpoint,
1359+
sample_operation,
1360+
):
1361+
"""A user-supplied httpx_client_factory is stored on the tool."""
1362+
custom_factory = MagicMock()
1363+
tool = RestApiTool(
1364+
name="test_tool",
1365+
description="Test Tool",
1366+
endpoint=sample_endpoint,
1367+
operation=sample_operation,
1368+
httpx_client_factory=custom_factory,
1369+
)
1370+
assert tool._httpx_client_factory is custom_factory
1371+
1372+
@pytest.mark.asyncio
1373+
async def test_call_uses_custom_httpx_client_factory(
1374+
self,
1375+
mock_tool_context,
1376+
sample_endpoint,
1377+
sample_operation,
1378+
sample_auth_scheme,
1379+
sample_auth_credential,
1380+
):
1381+
"""When a factory is provided, its client is used to issue the request."""
1382+
mock_response = mock.create_autospec(requests.Response, instance=True)
1383+
mock_response.json.return_value = {"result": "success"}
1384+
mock_response.configure_mock(status_code=200)
1385+
1386+
mock_client = mock.create_autospec(
1387+
httpx.AsyncClient, instance=True, spec_set=True
1388+
)
1389+
mock_client.request = AsyncMock(return_value=mock_response)
1390+
# Make the mock client work as an async context manager.
1391+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
1392+
mock_client.__aexit__ = AsyncMock(return_value=None)
1393+
1394+
custom_factory = MagicMock(return_value=mock_client)
1395+
1396+
tool = RestApiTool(
1397+
name="test_tool",
1398+
description="Test Tool",
1399+
endpoint=sample_endpoint,
1400+
operation=sample_operation,
1401+
auth_scheme=sample_auth_scheme,
1402+
auth_credential=sample_auth_credential,
1403+
httpx_client_factory=custom_factory,
1404+
)
1405+
1406+
with patch.object(httpx, "AsyncClient", autospec=True) as mock_default:
1407+
result = await tool.call(args={}, tool_context=mock_tool_context)
1408+
1409+
# Factory must be invoked once and the default client must not be built.
1410+
custom_factory.assert_called_once_with()
1411+
mock_default.assert_not_called()
1412+
mock_client.request.assert_awaited_once()
1413+
assert result == {"result": "success"}
1414+
1415+
@pytest.mark.asyncio
1416+
async def test_call_without_httpx_client_factory_uses_default_client(
1417+
self,
1418+
mock_tool_context,
1419+
sample_endpoint,
1420+
sample_operation,
1421+
sample_auth_scheme,
1422+
sample_auth_credential,
1423+
):
1424+
"""When no factory is provided, the default httpx.AsyncClient is used."""
1425+
mock_response = mock.create_autospec(requests.Response, instance=True)
1426+
mock_response.json.return_value = {"result": "success"}
1427+
mock_response.configure_mock(status_code=200)
1428+
1429+
mock_client = mock.create_autospec(
1430+
httpx.AsyncClient, instance=True, spec_set=True
1431+
)
1432+
mock_client.request = AsyncMock(return_value=mock_response)
1433+
1434+
tool = RestApiTool(
1435+
name="test_tool",
1436+
description="Test Tool",
1437+
endpoint=sample_endpoint,
1438+
operation=sample_operation,
1439+
auth_scheme=sample_auth_scheme,
1440+
auth_credential=sample_auth_credential,
1441+
)
1442+
1443+
with patch.object(
1444+
httpx, "AsyncClient", return_value=mock_client, autospec=True
1445+
) as mock_async_client:
1446+
await tool.call(args={}, tool_context=mock_tool_context)
1447+
assert mock_async_client.called
1448+
13421449
def test_prepare_request_params_extracts_embedded_query_params(
13431450
self, sample_auth_credential, sample_auth_scheme
13441451
):

0 commit comments

Comments
 (0)