Skip to content

Commit 559f0c2

Browse files
committed
Merge remote-tracking branch 'upstream/main' into topic/bump-litellm-cap
2 parents 5cfa649 + ad937fe commit 559f0c2

7 files changed

Lines changed: 166 additions & 20 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ optional-dependencies.docs = [
102102
"myst-parser",
103103
"sphinx<9",
104104
"sphinx-autodoc-typehints",
105+
"sphinx-click",
105106
"sphinx-rtd-theme",
106107
]
107108
optional-dependencies.eval = [

src/google/adk/integrations/agent_identity/gcp_auth_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _construct_auth_credential(
8282
return AuthCredential(
8383
auth_type=AuthCredentialTypes.HTTP,
8484
http=HttpAuth(
85-
scheme="bearer",
85+
scheme="Bearer",
8686
credentials=HttpCredentials(token=response.token),
8787
),
8888
)

src/google/adk/plugins/reflect_retry_tool_plugin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,8 @@ async def _handle_tool_error(
242242
"""
243243
if self.max_retries == 0:
244244
if self.throw_exception_if_retry_exceeded:
245-
raise error
246-
return self._get_tool_retry_exceed_msg(tool, error, tool_args)
245+
raise self._ensure_exception(error)
246+
return self._get_tool_retry_exceed_msg(tool, tool_args, error)
247247

248248
scope_key = self._get_scope_key(tool_context)
249249
async with self._lock:
@@ -260,7 +260,7 @@ async def _handle_tool_error(
260260

261261
# Max Retry exceeded
262262
if self.throw_exception_if_retry_exceeded:
263-
raise error
263+
raise self._ensure_exception(error)
264264
else:
265265
return self._get_tool_retry_exceed_msg(tool, tool_args, error)
266266

tests/integration/integrations/agent_identity/test_2lo_flow.py

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,6 @@
1818
from typing import Any
1919
from unittest import mock
2020

21-
import pytest
22-
23-
pytest.importorskip(
24-
"google.cloud.iamconnectorcredentials_v1alpha",
25-
reason="Requires google-cloud-iamconnectorcredentials",
26-
)
27-
2821
from google.adk import Agent
2922
from google.adk import Runner
3023
from google.adk.auth.auth_tool import AuthConfig
@@ -34,9 +27,12 @@
3427
from google.adk.integrations.agent_identity import GcpAuthProviderScheme
3528
from google.adk.sessions.in_memory_session_service import InMemorySessionService
3629
from google.adk.tools.base_authenticated_tool import BaseAuthenticatedTool
30+
from google.adk.tools.mcp_tool.mcp_tool import McpTool
3731
from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsRequest
3832
from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsResponse
3933
from google.genai import types
34+
from mcp.types import Tool as McpBaseTool
35+
import pytest
4036

4137
from tests.unittests import testing_utils
4238

@@ -200,3 +196,73 @@ async def test_gcp_agent_identity_2lo_gets_token() -> None:
200196

201197
assert function_response.name == "dummy_tool"
202198
assert DUMMY_TOKEN in str(function_response.response)
199+
200+
201+
@pytest.mark.parametrize("llm_backend", ["GOOGLE_AI"], indirect=True)
202+
@pytest.mark.asyncio
203+
async def test_gcp_agent_identity_2lo_sends_authorization_header_to_mcp_session(
204+
llm_backend: Any,
205+
) -> None:
206+
"""Ensures a 2LO token from GCP is correctly passed into the outbound MCP session headers."""
207+
CredentialManager._auth_provider_registry._providers.clear()
208+
CredentialManager.register_auth_provider(GcpAuthProvider())
209+
210+
mock_operation = _DummyOperation()
211+
with mock.patch.object(
212+
gcp_auth_provider, "Client", autospec=True
213+
) as mock_gcp:
214+
mock_gcp.return_value.retrieve_credentials.return_value = mock_operation
215+
216+
mock_session_mgr = mock.AsyncMock()
217+
mock_session_mgr.create_session.return_value.call_tool.return_value = (
218+
mock.MagicMock()
219+
)
220+
221+
mcp_tool = McpTool(
222+
mcp_tool=McpBaseTool(
223+
name="dummy_mcp",
224+
description="Dummy MCP tool for testing.",
225+
inputSchema={"type": "object", "properties": {}},
226+
),
227+
mcp_session_manager=mock_session_mgr,
228+
auth_scheme=GcpAuthProviderScheme(
229+
name=TEST_CONNECTOR_2LO, scopes=["test-scope"]
230+
),
231+
)
232+
233+
agent = Agent(
234+
name="test_agent",
235+
model=testing_utils.MockModel.create(
236+
responses=[
237+
types.Part.from_function_call(name="dummy_mcp", args={}),
238+
"Tool executed successfully.",
239+
]
240+
),
241+
instruction="Use dummy_mcp tool.",
242+
tools=[mcp_tool],
243+
)
244+
245+
async for _ in Runner(
246+
app_name="test_mcp_header_app",
247+
agent=agent,
248+
session_service=InMemorySessionService(),
249+
auto_create_session=True,
250+
).run_async(
251+
user_id="test_user",
252+
session_id="session-id-2",
253+
new_message=types.UserContent(parts=[types.Part(text="Run tool.")]),
254+
):
255+
pass
256+
257+
mock_gcp.return_value.retrieve_credentials.assert_called_once_with(
258+
RetrieveCredentialsRequest(
259+
connector=TEST_CONNECTOR_2LO,
260+
user_id="test_user",
261+
scopes=["test-scope"],
262+
force_refresh=False,
263+
)
264+
)
265+
266+
assert mock_session_mgr.create_session.call_args.kwargs.get("headers") == {
267+
"Authorization": f"Bearer {DUMMY_TOKEN}"
268+
}

tests/integration/integrations/agent_identity/test_3lo_flow.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,6 @@
1818
from typing import Any
1919
from unittest import mock
2020

21-
import pytest
22-
23-
pytest.importorskip(
24-
"google.cloud.iamconnectorcredentials_v1alpha",
25-
reason="Requires google-cloud-iamconnectorcredentials",
26-
)
27-
2821
from google.adk import Agent
2922
from google.adk import Runner
3023
from google.adk.auth.auth_tool import AuthConfig
@@ -38,6 +31,7 @@
3831
from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsRequest
3932
from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsResponse
4033
from google.genai import types
34+
import pytest
4135

4236
from tests.unittests import testing_utils
4337

tests/unittests/integrations/agent_identity/test_gcp_auth_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ async def test_get_auth_credential_returns_credential_if_available_immediately(
169169
auth_credential = await provider.get_auth_credential(auth_config, context)
170170

171171
assert auth_credential.auth_type == AuthCredentialTypes.HTTP
172-
assert auth_credential.http.scheme == "bearer"
172+
assert auth_credential.http.scheme == "Bearer"
173173
assert auth_credential.http.credentials.token == "test-token"
174174
mock_client.retrieve_credentials.assert_called_once()
175175

@@ -433,7 +433,7 @@ async def test_get_auth_credential_returns_token_if_consent_was_completed(
433433
# Verify
434434
assert auth_credential is not None
435435
assert auth_credential.auth_type == AuthCredentialTypes.HTTP
436-
assert auth_credential.http.scheme == "bearer"
436+
assert auth_credential.http.scheme == "Bearer"
437437
assert auth_credential.http.credentials.token == "test-token"
438438

439439

tests/unittests/plugins/test_reflect_retry_tool_plugin.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,57 @@ async def test_on_tool_error_callback_max_retries_zero(self):
168168
# Should re-raise the original exception when max_retries is 0
169169
self.assertIs(cm.exception, error)
170170

171+
async def test_on_tool_error_callback_max_retries_zero_without_exception(
172+
self,
173+
):
174+
"""Test error callback when max_retries is 0 and exception is disabled."""
175+
mock_tool = self.get_mock_tool()
176+
mock_tool_context = self.get_mock_tool_context()
177+
sample_tool_args = self.get_sample_tool_args()
178+
plugin = ReflectAndRetryToolPlugin(
179+
max_retries=0, throw_exception_if_retry_exceeded=False
180+
)
181+
error = ValueError("Test error")
182+
183+
result = await plugin.on_tool_error_callback(
184+
tool=mock_tool,
185+
tool_args=sample_tool_args,
186+
tool_context=mock_tool_context,
187+
error=error,
188+
)
189+
190+
# Should return a retry exceeded message instead of raising
191+
self.assertIsNotNone(result)
192+
self.assertEqual(result["response_type"], REFLECT_AND_RETRY_RESPONSE_TYPE)
193+
self.assertEqual(result["error_type"], "ValueError")
194+
self.assertEqual(result["retry_count"], 0)
195+
self.assertIn(
196+
"the retry limit has been exceeded", result["reflection_guidance"]
197+
)
198+
199+
async def test_on_tool_error_callback_max_retries_zero_with_dict_error(self):
200+
"""Test error callback when max_retries is 0 and error is a dict."""
201+
mock_tool = self.get_mock_tool()
202+
mock_tool_context = self.get_mock_tool_context()
203+
sample_tool_args = self.get_sample_tool_args()
204+
plugin = CustomErrorExtractionPlugin(
205+
max_retries=0, throw_exception_if_retry_exceeded=True
206+
)
207+
dict_error = {"status": "error", "message": "Custom dict error"}
208+
plugin.set_error_condition(lambda result: dict_error)
209+
210+
with self.assertRaises(Exception) as cm:
211+
await plugin.after_tool_callback(
212+
tool=mock_tool,
213+
tool_args=sample_tool_args,
214+
tool_context=mock_tool_context,
215+
result={"some": "result"},
216+
)
217+
218+
# Should raise an Exception wrapping the dict
219+
self.assertNotIsInstance(cm.exception, TypeError)
220+
self.assertIn("Custom dict error", str(cm.exception))
221+
171222
async def test_on_tool_error_callback_first_failure(self):
172223
"""Test first tool failure creates reflection response."""
173224
plugin = self.get_plugin()
@@ -280,6 +331,40 @@ async def test_max_retries_exceeded_with_exception(self):
280331
# Verify exception properties
281332
self.assertIs(cm.exception, error)
282333

334+
async def test_max_retries_exceeded_with_dict_error(self):
335+
"""Test that Exception is raised when max retries exceeded with dict error."""
336+
mock_tool = self.get_mock_tool()
337+
mock_tool_context = self.get_mock_tool_context()
338+
sample_tool_args = self.get_sample_tool_args()
339+
plugin = CustomErrorExtractionPlugin(
340+
max_retries=1, throw_exception_if_retry_exceeded=True
341+
)
342+
dict_error = {"status": "error", "message": "Custom dict error"}
343+
plugin.set_error_condition(lambda result: dict_error)
344+
345+
# First call should fail and return a retry response
346+
result1 = await plugin.after_tool_callback(
347+
tool=mock_tool,
348+
tool_args=sample_tool_args,
349+
tool_context=mock_tool_context,
350+
result={"some": "result"},
351+
)
352+
self.assertIsNotNone(result1)
353+
self.assertEqual(result1["retry_count"], 1)
354+
355+
# Second call should exceed max_retries and raise
356+
with self.assertRaises(Exception) as cm:
357+
await plugin.after_tool_callback(
358+
tool=mock_tool,
359+
tool_args=sample_tool_args,
360+
tool_context=mock_tool_context,
361+
result={"some": "result"},
362+
)
363+
364+
# Verify exception properties
365+
self.assertNotIsInstance(cm.exception, TypeError)
366+
self.assertIn("Custom dict error", str(cm.exception))
367+
283368
async def test_max_retries_exceeded_without_exception(self):
284369
"""Test max retries exceeded returns failure message when exception is disabled."""
285370
mock_tool = self.get_mock_tool()

0 commit comments

Comments
 (0)