Skip to content

Commit 1deab6d

Browse files
xuanyang15copybara-github
authored andcommitted
fix: Fix exception handling and argument order in ReflectRetryToolPlugin
Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 906753501
1 parent 684a6e7 commit 1deab6d

4 files changed

Lines changed: 89 additions & 50 deletions

File tree

src/google/adk/optimization/local_eval_sampler.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -289,11 +289,7 @@ def _extract_eval_data(
289289
for eval_metric_result in per_invocation_result.eval_metric_results:
290290
eval_metric_results.append({
291291
"metric_name": eval_metric_result.metric_name,
292-
"score": (
293-
round(eval_metric_result.score, 2)
294-
if eval_metric_result.score is not None
295-
else None
296-
), # accurate enough
292+
"score": round(eval_metric_result.score, 2), # accurate enough
297293
"eval_status": eval_metric_result.eval_status.name,
298294
})
299295
per_invocation_result_dict = {

src/google/adk/plugins/reflect_retry_tool_plugin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,8 @@ async def _handle_tool_error(
242242
"""
243243
if self.max_retries == 0:
244244
if self.throw_exception_if_retry_exceeded:
245-
raise error
246-
return self._get_tool_retry_exceed_msg(tool, error, tool_args)
245+
raise self._ensure_exception(error)
246+
return self._get_tool_retry_exceed_msg(tool, tool_args, error)
247247

248248
scope_key = self._get_scope_key(tool_context)
249249
async with self._lock:
@@ -260,7 +260,7 @@ async def _handle_tool_error(
260260

261261
# Max Retry exceeded
262262
if self.throw_exception_if_retry_exceeded:
263-
raise error
263+
raise self._ensure_exception(error)
264264
else:
265265
return self._get_tool_retry_exceed_msg(tool, tool_args, error)
266266

tests/unittests/optimization/local_eval_sampler_test.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -338,48 +338,6 @@ async def test_extract_eval_data(mocker):
338338
]
339339

340340

341-
def test_extract_eval_data_preserves_none_metric_score(mocker):
342-
mock_eval_sets_manager = mocker.MagicMock(spec=EvalSetsManager)
343-
mock_eval_case = mocker.MagicMock()
344-
mock_eval_case.conversation_scenario = "test_scenario"
345-
mock_eval_sets_manager.get_eval_case.return_value = mock_eval_case
346-
347-
mock_metric_result = mocker.MagicMock(spec=EvalMetricResult)
348-
mock_metric_result.metric_name = "test_metric"
349-
mock_metric_result.score = None
350-
mock_metric_result.eval_status = EvalStatus.NOT_EVALUATED
351-
352-
mock_per_inv_result = mocker.MagicMock(spec=EvalMetricResultPerInvocation)
353-
mock_per_inv_result.actual_invocation = mocker.MagicMock(spec=Invocation)
354-
mock_per_inv_result.expected_invocation = mocker.MagicMock(spec=Invocation)
355-
mock_per_inv_result.eval_metric_results = [mock_metric_result]
356-
357-
mock_eval_result = mocker.MagicMock(spec=EvalCaseResult)
358-
mock_eval_result.eval_id = "t1"
359-
mock_eval_result.eval_metric_result_per_invocation = [mock_per_inv_result]
360-
361-
mocker.patch(
362-
"google.adk.optimization.local_eval_sampler.extract_single_invocation_info",
363-
side_effect=[{"info": "actual"}, {"info": "expected"}],
364-
)
365-
366-
config = LocalEvalSamplerConfig(
367-
eval_config=EvalConfig(),
368-
app_name="test_app",
369-
train_eval_set="train_set",
370-
train_eval_case_ids=["t1"],
371-
)
372-
interface = LocalEvalSampler(config, mock_eval_sets_manager)
373-
374-
eval_data = interface._extract_eval_data("train_set", [mock_eval_result])
375-
376-
assert eval_data["t1"]["invocations"][0]["eval_metric_results"] == [{
377-
"metric_name": "test_metric",
378-
"score": None,
379-
"eval_status": "NOT_EVALUATED",
380-
}]
381-
382-
383341
@pytest.mark.asyncio
384342
async def test_sample_and_score(mocker):
385343
# Mock results

tests/unittests/plugins/test_reflect_retry_tool_plugin.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,57 @@ async def test_on_tool_error_callback_max_retries_zero(self):
168168
# Should re-raise the original exception when max_retries is 0
169169
self.assertIs(cm.exception, error)
170170

171+
async def test_on_tool_error_callback_max_retries_zero_without_exception(
172+
self,
173+
):
174+
"""Test error callback when max_retries is 0 and exception is disabled."""
175+
mock_tool = self.get_mock_tool()
176+
mock_tool_context = self.get_mock_tool_context()
177+
sample_tool_args = self.get_sample_tool_args()
178+
plugin = ReflectAndRetryToolPlugin(
179+
max_retries=0, throw_exception_if_retry_exceeded=False
180+
)
181+
error = ValueError("Test error")
182+
183+
result = await plugin.on_tool_error_callback(
184+
tool=mock_tool,
185+
tool_args=sample_tool_args,
186+
tool_context=mock_tool_context,
187+
error=error,
188+
)
189+
190+
# Should return a retry exceeded message instead of raising
191+
self.assertIsNotNone(result)
192+
self.assertEqual(result["response_type"], REFLECT_AND_RETRY_RESPONSE_TYPE)
193+
self.assertEqual(result["error_type"], "ValueError")
194+
self.assertEqual(result["retry_count"], 0)
195+
self.assertIn(
196+
"the retry limit has been exceeded", result["reflection_guidance"]
197+
)
198+
199+
async def test_on_tool_error_callback_max_retries_zero_with_dict_error(self):
200+
"""Test error callback when max_retries is 0 and error is a dict."""
201+
mock_tool = self.get_mock_tool()
202+
mock_tool_context = self.get_mock_tool_context()
203+
sample_tool_args = self.get_sample_tool_args()
204+
plugin = CustomErrorExtractionPlugin(
205+
max_retries=0, throw_exception_if_retry_exceeded=True
206+
)
207+
dict_error = {"status": "error", "message": "Custom dict error"}
208+
plugin.set_error_condition(lambda result: dict_error)
209+
210+
with self.assertRaises(Exception) as cm:
211+
await plugin.after_tool_callback(
212+
tool=mock_tool,
213+
tool_args=sample_tool_args,
214+
tool_context=mock_tool_context,
215+
result={"some": "result"},
216+
)
217+
218+
# Should raise an Exception wrapping the dict
219+
self.assertNotIsInstance(cm.exception, TypeError)
220+
self.assertIn("Custom dict error", str(cm.exception))
221+
171222
async def test_on_tool_error_callback_first_failure(self):
172223
"""Test first tool failure creates reflection response."""
173224
plugin = self.get_plugin()
@@ -280,6 +331,40 @@ async def test_max_retries_exceeded_with_exception(self):
280331
# Verify exception properties
281332
self.assertIs(cm.exception, error)
282333

334+
async def test_max_retries_exceeded_with_dict_error(self):
335+
"""Test that Exception is raised when max retries exceeded with dict error."""
336+
mock_tool = self.get_mock_tool()
337+
mock_tool_context = self.get_mock_tool_context()
338+
sample_tool_args = self.get_sample_tool_args()
339+
plugin = CustomErrorExtractionPlugin(
340+
max_retries=1, throw_exception_if_retry_exceeded=True
341+
)
342+
dict_error = {"status": "error", "message": "Custom dict error"}
343+
plugin.set_error_condition(lambda result: dict_error)
344+
345+
# First call should fail and return a retry response
346+
result1 = await plugin.after_tool_callback(
347+
tool=mock_tool,
348+
tool_args=sample_tool_args,
349+
tool_context=mock_tool_context,
350+
result={"some": "result"},
351+
)
352+
self.assertIsNotNone(result1)
353+
self.assertEqual(result1["retry_count"], 1)
354+
355+
# Second call should exceed max_retries and raise
356+
with self.assertRaises(Exception) as cm:
357+
await plugin.after_tool_callback(
358+
tool=mock_tool,
359+
tool_args=sample_tool_args,
360+
tool_context=mock_tool_context,
361+
result={"some": "result"},
362+
)
363+
364+
# Verify exception properties
365+
self.assertNotIsInstance(cm.exception, TypeError)
366+
self.assertIn("Custom dict error", str(cm.exception))
367+
283368
async def test_max_retries_exceeded_without_exception(self):
284369
"""Test max retries exceeded returns failure message when exception is disabled."""
285370
mock_tool = self.get_mock_tool()

0 commit comments

Comments
 (0)