Skip to content

Commit e31de73

Browse files
committed
Standardize tool error capture with ToolErrorData across framework handlers
Signed-off-by: Eric Evans <194135482+ericevans-nv@users.noreply.github.com>
1 parent 69e78d0 commit e31de73

14 files changed

Lines changed: 244 additions & 451 deletions

File tree

packages/nvidia_nat_adk/src/nat/plugins/adk/callback_handler.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from nat.data_models.intermediate_step import IntermediateStepPayload
2626
from nat.data_models.intermediate_step import IntermediateStepType
2727
from nat.data_models.intermediate_step import StreamEventData
28+
from nat.data_models.intermediate_step import ToolErrorData
2829
from nat.data_models.intermediate_step import TraceMetadata
2930
from nat.data_models.intermediate_step import UsageInfo
3031
from nat.data_models.profiler_callback import BaseProfilerCallback
@@ -200,8 +201,29 @@ async def wrapped_tool_use(base_tool_instance, *args, **kwargs) -> Any:
200201

201202
return result
202203

203-
except Exception as _e:
204-
logger.exception("BaseTool error occured")
204+
except Exception as e:
205+
logger.error("BaseTool error: %s", e)
206+
kwargs_args = (kwargs.get("args", {}) if isinstance(kwargs.get("args"), dict) else {})
207+
tool_error: ToolErrorData = ToolErrorData(
208+
content=f"{type(e).__name__}: {e!s}",
209+
error_type=type(e).__name__,
210+
error_message=str(e),
211+
)
212+
self.step_manager.push_intermediate_step(
213+
IntermediateStepPayload(
214+
event_type=IntermediateStepType.TOOL_END,
215+
span_event_timestamp=time.time(),
216+
framework=LLMFrameworkEnum.ADK,
217+
name=tool_name,
218+
data=StreamEventData(
219+
input={
220+
"args": args, "kwargs": dict(kwargs_args)
221+
},
222+
output=tool_error,
223+
),
224+
usage_info=UsageInfo(token_usage=TokenUsageBaseModel()),
225+
UUID=step_uuid,
226+
))
205227
raise
206228

207229
return wrapped_tool_use

packages/nvidia_nat_adk/tests/test_adk_callback_handler.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from nat.data_models.intermediate_step import IntermediateStepType
2525
from nat.data_models.intermediate_step import LLMFrameworkEnum
26+
from nat.data_models.intermediate_step import ToolErrorData
2627
from nat.data_models.profiler_callback import BaseProfilerCallback
2728
from nat.plugins.adk.callback_handler import ADKProfilerHandler
2829

@@ -219,30 +220,34 @@ async def test_tool_use_monkey_patch_functionality(handler, mock_context):
219220

220221
@pytest.mark.asyncio
221222
async def test_tool_use_monkey_patch_with_exception(handler, mock_context):
222-
"""Test tool use monkey patch handles exceptions properly."""
223-
# Create a mock tool instance
223+
"""When a tool raises an exception, TOOL_END event contains ToolErrorData with parsed error details."""
224224
mock_tool_instance = MagicMock()
225-
mock_tool_instance.name = "test_tool"
225+
mock_tool_instance.name = "lookup_tool"
226226

227-
# Create mock original function that raises an exception
228-
mock_original_func = AsyncMock(side_effect=Exception("Tool error"))
227+
mock_original_func = AsyncMock(side_effect=ValueError("Column 'revenue' not found"))
229228
handler._original_tool_call = mock_original_func
230229

231-
# Get the wrapped function
232230
wrapped_func = handler._tool_use_monkey_patch()
233231

234-
# Test that exception is re-raised
235-
with pytest.raises(Exception, match="Tool error"):
232+
with pytest.raises(ValueError, match="Column 'revenue' not found"):
236233
await wrapped_func(mock_tool_instance, "arg1")
237234

238-
# Verify original function was called
239235
mock_original_func.assert_called_once()
240236

241-
# Verify start event was still pushed
242-
assert mock_context.push_intermediate_step.call_count >= 1
237+
assert mock_context.push_intermediate_step.call_count == 2
243238
start_call = mock_context.push_intermediate_step.call_args_list[0][0][0]
244239
assert start_call.event_type == IntermediateStepType.TOOL_START
245240

241+
end_call = mock_context.push_intermediate_step.call_args_list[1][0][0]
242+
assert end_call.event_type == IntermediateStepType.TOOL_END
243+
assert end_call.name == "lookup_tool"
244+
assert isinstance(end_call.data.output, ToolErrorData)
245+
246+
error_data: ToolErrorData = end_call.data.output
247+
assert error_data.content == "ValueError: Column 'revenue' not found"
248+
assert error_data.error_type == "ValueError"
249+
assert error_data.error_message == "Column 'revenue' not found"
250+
246251

247252
@pytest.mark.asyncio
248253
async def test_tool_use_monkey_patch_tool_name_error(handler, mock_context):

packages/nvidia_nat_autogen/src/nat/plugins/autogen/callback_handler.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from nat.data_models.intermediate_step import IntermediateStepPayload
4646
from nat.data_models.intermediate_step import IntermediateStepType
4747
from nat.data_models.intermediate_step import StreamEventData
48+
from nat.data_models.intermediate_step import ToolErrorData
4849
from nat.data_models.intermediate_step import TraceMetadata
4950
from nat.data_models.intermediate_step import UsageInfo
5051
from nat.data_models.profiler_callback import BaseProfilerCallback
@@ -557,15 +558,18 @@ async def wrapped_tool_call(*args: Any, **kwargs: Any) -> Any:
557558
except Exception:
558559
logger.debug("Error getting tool name")
559560

560-
# Extract tool input
561-
tool_input = ""
561+
# Extract tool input from args
562+
# run_json signature: (self, args: Mapping[str, Any], cancellation_token, call_id=None)
563+
# args[0] = self (tool instance)
564+
# args[1] = args (the tool arguments as a Mapping)
565+
tool_input: dict[str, Any] = {}
562566
try:
563567
if len(args) > 1:
564-
call_data = args[1]
565-
if hasattr(call_data, "kwargs"):
566-
tool_input = str(call_data.kwargs)
567-
elif isinstance(call_data, dict):
568-
tool_input = str(call_data.get("kwargs", {}))
568+
tool_args = args[1]
569+
if isinstance(tool_args, dict):
570+
tool_input = dict(tool_args)
571+
elif hasattr(tool_args, "items"):
572+
tool_input = dict(tool_args)
569573
except Exception:
570574
logger.debug("Error extracting tool input")
571575

@@ -590,14 +594,18 @@ async def wrapped_tool_call(*args: Any, **kwargs: Any) -> Any:
590594
output = await original_func(*args, **kwargs)
591595
except Exception as e:
592596
logger.error("Tool execution failed: %s", e)
597+
tool_error: ToolErrorData = ToolErrorData(
598+
content=f"{type(e).__name__}: {e!s}",
599+
error_type=type(e).__name__,
600+
error_message=str(e),
601+
)
593602
handler.step_manager.push_intermediate_step(
594603
IntermediateStepPayload(
595604
event_type=IntermediateStepType.TOOL_END,
596605
span_event_timestamp=time.time(),
597606
framework=LLMFrameworkEnum.AUTOGEN,
598607
name=tool_name,
599-
data=StreamEventData(input=tool_input, output=str(e)),
600-
metadata=TraceMetadata(error=str(e)),
608+
data=StreamEventData(input=tool_input, output=tool_error),
601609
usage_info=UsageInfo(token_usage=TokenUsageBaseModel()),
602610
UUID=start_uuid,
603611
))

packages/nvidia_nat_autogen/tests/test_callback_handler_autogen.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import pytest
2424

25+
from nat.data_models.intermediate_step import ToolErrorData
2526
from nat.plugins.autogen.callback_handler import AutoGenProfilerHandler
2627
from nat.plugins.autogen.callback_handler import ClientPatchInfo
2728
from nat.plugins.autogen.callback_handler import PatchedClients
@@ -561,29 +562,56 @@ async def test_tool_wrapper_handles_dict_input(self, mock_get):
561562

562563
@patch('nat.plugins.autogen.callback_handler.Context.get')
563564
async def test_tool_wrapper_handles_exception(self, mock_get):
564-
"""Test tool wrapper handles tool execution errors."""
565+
"""When a tool raises an exception, TOOL_END event contains ToolErrorData with parsed error details."""
565566
mock_context = Mock()
566567
mock_step_manager = Mock()
567568
mock_context.intermediate_step_manager = mock_step_manager
568569
mock_get.return_value = mock_context
569570

570571
handler = AutoGenProfilerHandler()
571572

572-
original_func = AsyncMock(side_effect=ValueError("Tool failed"))
573+
original_func = AsyncMock(side_effect=ValueError("Column 'revenue' not found"))
573574
wrapped = handler._create_tool_wrapper(original_func)
574575

575576
tool = Mock()
576-
tool.name = "failing_tool"
577+
tool.name = "lookup_tool"
577578
call_data = Mock()
578579
call_data.kwargs = {}
579580

580-
with pytest.raises(ValueError, match="Tool failed"):
581+
with pytest.raises(ValueError, match="Column 'revenue' not found"):
581582
await wrapped(tool, call_data)
582583

583-
# Should have START and error END
584584
assert mock_step_manager.push_intermediate_step.call_count == 2
585585
error_call = mock_step_manager.push_intermediate_step.call_args_list[1][0][0]
586-
assert "Tool failed" in error_call.data.output
586+
assert error_call.name == "lookup_tool"
587+
assert isinstance(error_call.data.output, ToolErrorData)
588+
589+
error_data: ToolErrorData = error_call.data.output
590+
assert error_data.content == "ValueError: Column 'revenue' not found"
591+
assert error_data.error_type == "ValueError"
592+
assert error_data.error_message == "Column 'revenue' not found"
593+
594+
@patch('nat.plugins.autogen.callback_handler.Context.get')
595+
async def test_tool_wrapper_extracts_input_from_run_json_args(self, mock_get):
596+
"""Test tool wrapper extracts input from run_json signature: (self, args: Mapping, ...)."""
597+
mock_context = Mock()
598+
mock_step_manager = Mock()
599+
mock_context.intermediate_step_manager = mock_step_manager
600+
mock_get.return_value = mock_context
601+
602+
handler = AutoGenProfilerHandler()
603+
604+
original_func = AsyncMock(return_value="result")
605+
wrapped = handler._create_tool_wrapper(original_func)
606+
607+
tool = Mock()
608+
tool.name = "failing_lookup"
609+
tool_args = {"query": "revenue"}
610+
611+
await wrapped(tool, tool_args, Mock())
612+
613+
start_event = mock_step_manager.push_intermediate_step.call_args_list[0][0][0]
614+
assert start_event.data.input == {"query": "revenue"}
587615

588616

589617
class TestIntegration:

packages/nvidia_nat_core/src/nat/data_models/intermediate_step.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ class IntermediateStepState(StrEnum):
6868
END = "END"
6969

7070

71+
class ToolErrorData(BaseModel):
72+
"""ToolErrorData is a data model that represents the output field in a TOOL_END event when an error occurs."""
73+
74+
content: str = Field(description="Full error string, e.g. 'ValueError: Column not found'")
75+
error_type: str = Field(description="Exception type, e.g. 'ValueError'")
76+
error_message: str = Field(description="Error message without type, e.g. 'Column not found'")
77+
78+
7179
class StreamEventData(BaseModel):
7280
"""
7381
StreamEventData is a data model that represents the data field in an streaming event.

packages/nvidia_nat_core/src/nat/utils/atif_converter.py

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from nat.data_models.intermediate_step import IntermediateStepCategory
4747
from nat.data_models.intermediate_step import IntermediateStepState
4848
from nat.data_models.intermediate_step import IntermediateStepType
49+
from nat.data_models.intermediate_step import ToolErrorData
4950
from nat.data_models.intermediate_step import TraceMetadata
5051

5152
logger = logging.getLogger(__name__)
@@ -101,29 +102,6 @@ def _safe_str(value: Any) -> str:
101102
return str(value)
102103

103104

104-
def _extract_tool_error(output: Any) -> dict[str, str] | None:
105-
"""Extract error metadata from a tool output for ``step.extra["tool_errors"]``."""
106-
# TODO: return a model instead of a plain dict once ATIF spec adds error support
107-
status: str | None = getattr(output, "status", None) or (output.get("status") if isinstance(output, dict) else None)
108-
if status != "error":
109-
return None
110-
content: str = (getattr(output, "content", None) or (output.get("content") if isinstance(output, dict) else None)
111-
or _safe_str(output))
112-
error_type: str = "Unknown"
113-
error_message: str = content
114-
if ":" in content:
115-
candidate: str = content.split(":", 1)[0].strip()
116-
if candidate.isidentifier():
117-
error_type = candidate
118-
error_message = content.split(":", 1)[1].strip()
119-
return {
120-
"error": content,
121-
"error_type": error_type,
122-
"error_message": error_message,
123-
"status": "error",
124-
}
125-
126-
127105
def _extract_user_input(value: Any) -> str:
128106
"""Extract the user-facing input text from a workflow start payload.
129107
@@ -366,11 +344,15 @@ def _flush_pending() -> None:
366344
call_id = f"call_{ist.UUID}"
367345
tc = ATIFToolCall(tool_call_id=call_id, function_name=tool_name, arguments=tool_input)
368346
obs = ATIFObservationResult(source_call_id=call_id, content=tool_output)
369-
tool_error: dict[str, str] | None = _extract_tool_error(raw_output)
370347

371-
if tool_error is not None:
372-
tool_error["tool"] = tool_name
373-
extra: dict[str, Any] | None = ({"tool_errors": [tool_error]} if tool_error else None)
348+
tool_error: dict[str, str] | None = None
349+
if isinstance(raw_output, ToolErrorData):
350+
tool_error = {
351+
"tool": tool_name,
352+
"error": raw_output.content,
353+
"error_type": raw_output.error_type,
354+
"error_message": raw_output.error_message,
355+
}
374356

375357
if pending is not None:
376358
pending.tool_calls.append(tc)
@@ -379,7 +361,7 @@ def _flush_pending() -> None:
379361
pending.extra.setdefault("tool_errors", []).append(tool_error)
380362
pending.tool_ancestry.append(_atif_ancestry_from_ist(ist))
381363
else:
382-
extra = _atif_step_extra_model_from_ist(ist).model_dump(exclude_none=True)
364+
extra: dict[str, Any] = _atif_step_extra_model_from_ist(ist).model_dump(exclude_none=True)
383365
if tool_error:
384366
extra.setdefault("tool_errors", []).append(tool_error)
385367
atif_steps.append(
@@ -552,9 +534,16 @@ def push(self, ist: IntermediateStep) -> ATIFStep | None:
552534
call_id = f"call_{ist.UUID}"
553535
tc = ATIFToolCall(tool_call_id=call_id, function_name=tool_name, arguments=tool_input)
554536
obs = ATIFObservationResult(source_call_id=call_id, content=tool_output)
555-
tool_error: dict[str, str] | None = _extract_tool_error(raw_output)
556-
if tool_error is not None:
557-
tool_error["tool"] = tool_name
537+
538+
tool_error: dict[str, str] | None = None
539+
if isinstance(raw_output, ToolErrorData):
540+
tool_error = {
541+
"tool": tool_name,
542+
"error": raw_output.content,
543+
"error_type": raw_output.error_type,
544+
"error_message": raw_output.error_message,
545+
}
546+
558547
if self._pending is not None:
559548
self._pending.tool_calls.append(tc)
560549
self._pending.observations.append(obs)
@@ -563,7 +552,7 @@ def push(self, ist: IntermediateStep) -> ATIFStep | None:
563552
self._pending.tool_ancestry.append(_atif_ancestry_from_ist(ist))
564553
return None
565554

566-
extra = _atif_step_extra_model_from_ist(ist).model_dump(exclude_none=True)
555+
extra: dict[str, Any] = _atif_step_extra_model_from_ist(ist).model_dump(exclude_none=True)
567556
if tool_error:
568557
extra.setdefault("tool_errors", []).append(tool_error)
569558
orphan_step = ATIFStep(

0 commit comments

Comments
 (0)