Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/agents/run_internal/tool_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from ..tracing import Span, SpanError, function_span, get_current_trace
from ..util import _coro, _error_tracing
from ..util._approvals import evaluate_needs_approval_setting
from ..util._tool_errors import get_trace_tool_error
from ..util._types import MaybeAwaitable
from ._asyncio_progress import get_function_tool_task_progress_deadline
from .agent_bindings import AgentBindings, bind_public_agent
Expand Down Expand Up @@ -152,7 +153,6 @@
"execute_approved_tools",
]

# Fixed placeholder that get_trace_tool_error returns instead of the real error
# text when trace_include_sensitive_data is false.
REDACTED_TOOL_ERROR_MESSAGE = "Tool execution failed. Error details are redacted."
TToolSpanResult = TypeVar("TToolSpanResult")
# NOTE(review): presumably bounds applied while draining cancelled function tools
# (a duration in seconds and a step count) -- confirm at the use sites.
_FUNCTION_TOOL_CANCELLED_DRAIN_SECONDS = 0.25
_FUNCTION_TOOL_CANCELLED_IMMEDIATE_STEP_LIMIT = 64
Expand Down Expand Up @@ -1013,11 +1013,6 @@ def format_shell_error(error: Exception | BaseException | Any) -> str:
return repr(error)


def get_trace_tool_error(*, trace_include_sensitive_data: bool, error_message: str) -> str:
    """Pick the tool error text that may be recorded on a trace.

    Args:
        trace_include_sensitive_data: Whether the raw error text may be recorded.
        error_message: The raw error text.

    Returns:
        ``error_message`` unchanged when sensitive data is allowed, otherwise
        ``REDACTED_TOOL_ERROR_MESSAGE``.
    """
    if trace_include_sensitive_data:
        return error_message
    return REDACTED_TOOL_ERROR_MESSAGE


async def with_tool_function_span(
*,
config: RunConfig,
Expand Down Expand Up @@ -1570,10 +1565,14 @@ async def _run_single_tool(
agent_hooks=agent_hooks,
)
except Exception as e:
trace_error = get_trace_tool_error(
trace_include_sensitive_data=self.config.trace_include_sensitive_data,
error_message=str(e),
)
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool",
data={"tool_name": func_tool.name, "error": str(e)},
data={"tool_name": func_tool.name, "error": trace_error},
)
)
if isinstance(e, AgentsException):
Expand Down Expand Up @@ -1732,10 +1731,14 @@ async def _invoke_tool_and_run_post_invoke(
if result is None:
raise

trace_error = get_trace_tool_error(
trace_include_sensitive_data=self.config.trace_include_sensitive_data,
error_message=str(e),
)
_error_tracing.attach_error_to_current_span(
SpanError(
message="Tool execution cancelled",
data={"tool_name": func_tool.name, "error": str(e)},
data={"tool_name": func_tool.name, "error": trace_error},
)
)
real_result = result
Expand Down
47 changes: 40 additions & 7 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from .tool_guardrails import ToolInputGuardrail, ToolOutputGuardrail
from .tracing import SpanError
from .util import _error_tracing
from .util._tool_errors import get_trace_tool_error
from .util._types import MaybeAwaitable

if TYPE_CHECKING:
Expand Down Expand Up @@ -422,7 +423,7 @@ class _FailureHandlingFunctionToolInvoker:
def __init__(
self,
invoke_tool_impl: Callable[[ToolContext[Any], str], Awaitable[Any]],
on_handled_error: Callable[[FunctionTool, Exception, str], None],
on_handled_error: Callable[[FunctionTool, Exception, str, ToolContext[Any]], None],
*,
function_tool: FunctionTool | None = None,
) -> None:
Expand Down Expand Up @@ -457,7 +458,7 @@ async def __call__(self, ctx: ToolContext[Any], input: str) -> Any:
if result is None:
raise

self._on_handled_error(self._function_tool, e, input)
self._on_handled_error(self._function_tool, e, input, ctx)
return result


Expand All @@ -466,6 +467,26 @@ def with_function_tool_failure_error_handler(
on_handled_error: Callable[[FunctionTool, Exception, str], None],
) -> Callable[[ToolContext[Any], str], Awaitable[Any]]:
"""Wrap a tool invoker so copied FunctionTools resolve failure policy against themselves."""

def _on_handled_error_with_context(
function_tool: FunctionTool,
error: Exception,
input_json: str,
_context: ToolContext[Any],
) -> None:
on_handled_error(function_tool, error, input_json)

return _with_context_function_tool_failure_error_handler(
invoke_tool_impl,
_on_handled_error_with_context,
)


def _with_context_function_tool_failure_error_handler(
    invoke_tool_impl: Callable[[ToolContext[Any], str], Awaitable[Any]],
    on_handled_error: Callable[[FunctionTool, Exception, str, ToolContext[Any]], None],
) -> Callable[[ToolContext[Any], str], Awaitable[Any]]:
    """Build a tool invoker whose handled-error callback also receives the tool context.

    Args:
        invoke_tool_impl: The underlying tool invocation coroutine function.
        on_handled_error: Callback taking the function tool, the exception, the
            raw input string and the tool context.

    Returns:
        A ``_FailureHandlingFunctionToolInvoker`` wrapping ``invoke_tool_impl``.
    """
    failure_handling_invoker = _FailureHandlingFunctionToolInvoker(
        invoke_tool_impl,
        on_handled_error,
    )
    return failure_handling_invoker


Expand All @@ -475,7 +496,7 @@ def _build_wrapped_function_tool(
description: str,
params_json_schema: dict[str, Any],
invoke_tool_impl: Callable[[ToolContext[Any], str], Awaitable[Any]],
on_handled_error: Callable[[FunctionTool, Exception, str], None],
on_handled_error: Callable[[FunctionTool, Exception, str, ToolContext[Any]], None],
failure_error_function: ToolErrorFunction | None | object = _UNSET_FAILURE_ERROR_FUNCTION,
strict_json_schema: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
Expand All @@ -493,7 +514,7 @@ def _build_wrapped_function_tool(
tool_origin: ToolOrigin | None = None,
) -> FunctionTool:
"""Create a FunctionTool with copied-tool-aware failure handling bound in one place."""
on_invoke_tool = with_function_tool_failure_error_handler(
on_invoke_tool = _with_context_function_tool_failure_error_handler(
invoke_tool_impl,
on_handled_error,
)
Expand Down Expand Up @@ -1377,24 +1398,36 @@ def _build_handled_function_tool_error_handler(
span_message_for_json_decode_error: str | None = None,
include_input_json_in_logs: bool = True,
include_tool_name_in_log_messages: bool = True,
) -> Callable[[FunctionTool, Exception, str], None]:
) -> Callable[[FunctionTool, Exception, str, ToolContext[Any]], None]:
"""Create a consistent handled-error reporter for wrapped FunctionTools."""

def _on_handled_error(function_tool: FunctionTool, error: Exception, input_json: str) -> None:
def _on_handled_error(
function_tool: FunctionTool,
error: Exception,
input_json: str,
context: ToolContext[Any],
) -> None:
json_decode_error = _extract_tool_argument_json_error(error)
if json_decode_error is not None and span_message_for_json_decode_error is not None:
resolved_span_message = span_message_for_json_decode_error
span_error_detail = str(json_decode_error)
else:
resolved_span_message = span_message
span_error_detail = str(error)
trace_include_sensitive_data = (
context.run_config is None or context.run_config.trace_include_sensitive_data
)
trace_error = get_trace_tool_error(
trace_include_sensitive_data=trace_include_sensitive_data,
error_message=span_error_detail,
)

_error_tracing.attach_error_to_current_span(
SpanError(
message=resolved_span_message,
data={
"tool_name": function_tool.name,
"error": span_error_detail,
"error": trace_error,
},
)
)
Expand Down
8 changes: 8 additions & 0 deletions src/agents/util/_tool_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Utilities that turn tool errors into text that is safe to record on traces."""

REDACTED_TOOL_ERROR_MESSAGE = "Tool execution failed. Error details are redacted."


def get_trace_tool_error(*, trace_include_sensitive_data: bool, error_message: str) -> str:
    """Choose the tool error text to record on a trace.

    Args:
        trace_include_sensitive_data: Whether the raw error text may be recorded.
        error_message: The raw error text.

    Returns:
        ``error_message`` unchanged when sensitive data is allowed, otherwise the
        fixed ``REDACTED_TOOL_ERROR_MESSAGE`` placeholder.
    """
    if not trace_include_sensitive_data:
        return REDACTED_TOOL_ERROR_MESSAGE
    return error_message
123 changes: 121 additions & 2 deletions tests/test_run_step_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@
)


def _function_span_names() -> list[str]:
names: list[str] = []
def _function_spans() -> list[dict[str, Any]]:
function_spans: list[dict[str, Any]] = []
for span in SPAN_PROCESSOR_TESTING.get_ordered_spans(including_empty=True):
exported = span.export()
if not exported:
Expand All @@ -104,6 +104,16 @@ def _function_span_names() -> list[str]:
continue
if span_data.get("type") != "function":
continue
function_spans.append(exported)
return function_spans


def _function_span_names() -> list[str]:
names: list[str] = []
for exported in _function_spans():
span_data = exported.get("span_data")
if not isinstance(span_data, dict):
continue
name = span_data.get("name")
if isinstance(name, str):
names.append(name)
Expand Down Expand Up @@ -509,6 +519,78 @@ async def _error_tool() -> str:
await get_execute_result(agent, response)


# Verifies that when a tool built with failure_error_function=None raises, the run
# surfaces a UserError carrying the original message, while the error recorded on
# the function span is redacted under RunConfig(trace_include_sensitive_data=False).
@pytest.mark.asyncio
async def test_function_tool_error_trace_respects_sensitive_data_setting():
# Tool that always fails; the message stands in for sensitive data.
async def _error_tool() -> str:
raise ValueError("secret-token-123")

error_tool = function_tool(
_error_tool,
name_override="error_tool",
failure_error_function=None,
)
agent = Agent(name="test", tools=[error_tool])
# Model response containing a single call to the failing tool.
response = ModelResponse(
output=[get_function_tool_call("error_tool", "{}", call_id="1")],
usage=Usage(),
response_id=None,
)

with trace("test"):
# The raised error is expected to still contain the raw message.
with pytest.raises(UserError, match="Error running tool error_tool: secret-token-123"):
await get_execute_result(
agent,
response,
run_config=RunConfig(trace_include_sensitive_data=False),
)

function_spans = _function_spans()

# Exactly one function span, whose error data holds the redaction placeholder
# and no trace of the raw message.
assert len(function_spans) == 1
error = function_spans[0]["error"]
assert error["message"] == "Error running tool"
assert error["data"]["tool_name"] == "error_tool"
assert error["data"]["error"] == "Tool execution failed. Error details are redacted."
assert "secret-token-123" not in str(error)


# Verifies the same redaction for a tool created without an explicit
# failure_error_function: the run continues, the tool output item still contains
# the raw error message, and the function span error data is redacted under
# RunConfig(trace_include_sensitive_data=False).
@pytest.mark.asyncio
async def test_default_function_tool_error_trace_respects_sensitive_data_setting():
# Tool that always fails; the message stands in for sensitive data.
async def _error_tool() -> str:
raise ValueError("secret-token-123")

error_tool = function_tool(_error_tool, name_override="error_tool")
agent = Agent(name="test", tools=[error_tool])
# Model response containing a single call to the failing tool.
response = ModelResponse(
output=[get_function_tool_call("error_tool", "{}", call_id="1")],
usage=Usage(),
response_id=None,
)

with trace("test"):
result = await get_execute_result(
agent,
response,
run_config=RunConfig(trace_include_sensitive_data=False),
)

# The failure is reported as a tool output item (raw message included) and the
# next step is to run again.
assert len(result.generated_items) == 2
assert isinstance(result.next_step, NextStepRunAgain)
assert_item_is_function_tool_call_output(
result.generated_items[1],
"An error occurred while running the tool. Please try again. Error: secret-token-123",
)

function_spans = _function_spans()

# Exactly one function span, whose error data holds the redaction placeholder
# and no trace of the raw message.
assert len(function_spans) == 1
error = function_spans[0]["error"]
assert error["message"] == "Error running tool (non-fatal)"
assert error["data"]["tool_name"] == "error_tool"
assert error["data"]["error"] == "Tool execution failed. Error details are redacted."
assert "secret-token-123" not in str(error)


@pytest.mark.asyncio
async def test_multiple_tool_calls_still_raise_when_sibling_cancelled():
async def _ok_tool() -> str:
Expand Down Expand Up @@ -770,6 +852,43 @@ async def _cancel_tool() -> str:
)


# Verifies redaction on the cancellation path: a tool raising asyncio.CancelledError
# yields a tool output item containing the raw message, while the function span
# records "Tool execution cancelled" with redacted error data under
# RunConfig(trace_include_sensitive_data=False).
@pytest.mark.asyncio
async def test_cancelled_function_tool_error_trace_respects_sensitive_data_setting():
# Tool that always raises CancelledError; the message stands in for sensitive data.
async def _cancel_tool() -> str:
raise asyncio.CancelledError("secret-token-123")

cancel_tool = function_tool(_cancel_tool, name_override="cancel_tool")
agent = Agent(name="test", tools=[cancel_tool])
# Model response containing a single call to the cancelling tool.
response = ModelResponse(
output=[get_function_tool_call("cancel_tool", "{}", call_id="1")],
usage=Usage(),
response_id=None,
)

with trace("test"):
result = await get_execute_result(
agent,
response,
run_config=RunConfig(trace_include_sensitive_data=False),
)

# The cancellation is reported as a tool output item (raw message included) and
# the next step is to run again.
assert len(result.generated_items) == 2
assert isinstance(result.next_step, NextStepRunAgain)
assert_item_is_function_tool_call_output(
result.generated_items[1],
"An error occurred while running the tool. Please try again. Error: secret-token-123",
)

function_spans = _function_spans()

# Exactly one function span, whose error data holds the redaction placeholder
# and no trace of the raw message.
assert len(function_spans) == 1
error = function_spans[0]["error"]
assert error["message"] == "Tool execution cancelled"
assert error["data"]["tool_name"] == "cancel_tool"
assert error["data"]["error"] == "Tool execution failed. Error details are redacted."
assert "secret-token-123" not in str(error)


@pytest.mark.asyncio
async def test_multiple_tool_calls_surface_hook_failure_over_sibling_cancellation():
hook_started = asyncio.Event()
Expand Down