Skip to content

Commit 3e8fa48

Browse files
committed
Fix Python tool error handling
1 parent 843b644 commit 3e8fa48

3 files changed

Lines changed: 91 additions & 37 deletions

File tree

python/copilot/tools.py

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -115,35 +115,26 @@ def decorator(fn: Callable[..., Any]) -> Tool:
115115
schema = ptype.model_json_schema()
116116

117117
async def wrapped_handler(invocation: ToolInvocation) -> ToolResult:
118-
try:
119-
# Build args based on detected signature
120-
call_args = []
121-
if takes_params:
122-
args = invocation["arguments"] or {}
123-
if ptype is not None and _is_pydantic_model(ptype):
124-
call_args.append(ptype.model_validate(args))
125-
else:
126-
call_args.append(args)
127-
if takes_invocation:
128-
call_args.append(invocation)
129-
130-
result = fn(*call_args)
131-
132-
if inspect.isawaitable(result):
133-
result = await result
134-
135-
return _normalize_result(result)
136-
137-
except Exception as exc:
138-
# Don't expose detailed error information to the LLM for security reasons.
139-
# The actual error is stored in the 'error' field for debugging.
140-
return ToolResult(
141-
textResultForLlm="Invoking this tool produced an error. "
142-
"Detailed information is not available.",
143-
resultType="failure",
144-
error=str(exc),
145-
toolTelemetry={},
146-
)
118+
# Build args based on detected signature.
119+
# Exceptions are NOT caught here — they propagate to the SDK's
120+
# _execute_tool_call, which records errors on the execute_tool
121+
# span and builds a safe ToolResult for the LLM.
122+
call_args = []
123+
if takes_params:
124+
args = invocation["arguments"] or {}
125+
if ptype is not None and _is_pydantic_model(ptype):
126+
call_args.append(ptype.model_validate(args))
127+
else:
128+
call_args.append(args)
129+
if takes_invocation:
130+
call_args.append(invocation)
131+
132+
result = fn(*call_args)
133+
134+
if inspect.isawaitable(result):
135+
result = await result
136+
137+
return _normalize_result(result)
147138

148139
return Tool(
149140
name=tool_name,

python/e2e/test_tools_unit.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ def test_tool(params: Params) -> str:
169169
assert received_params is not None
170170
assert received_params.value == "hello"
171171

172-
async def test_handler_error_is_hidden_from_llm(self):
172+
async def test_handler_error_propagates(self):
173+
"""Exceptions raised by tool handlers propagate to the caller (the SDK's _execute_tool_call in client.py catches them)."""
174+
173175
class Params(BaseModel):
174176
pass
175177

@@ -184,13 +186,11 @@ def failing_tool(params: Params, invocation: ToolInvocation) -> str:
184186
"arguments": {},
185187
}
186188

187-
result = await failing_tool.handler(invocation)
188-
189-
assert result["resultType"] == "failure"
190-
assert "secret error message" not in result["textResultForLlm"]
191-
assert "error" in result["textResultForLlm"].lower()
192-
# But the actual error is stored internally
193-
assert result["error"] == "secret error message"
189+
# Exceptions propagate from define_tool handlers — the SDK's
190+
# _execute_tool_call catches them, records telemetry, and builds
191+
# a safe ToolResult that hides error details from the LLM.
192+
with pytest.raises(ValueError, match="secret error message"):
193+
await failing_tool.handler(invocation)
194194

195195
async def test_function_style_api(self):
196196
class Params(BaseModel):

python/test_opentelemetry.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,69 @@ def test_records_error_on_span(self, _reset_otel_globals):
495495
assert s.attributes[ATTR_ERROR_TYPE] == "ValueError"
496496
assert s.status.status_code == trace.StatusCode.ERROR
497497

498+
def test_execute_tool_error_from_define_tool_handler(self, _reset_otel_globals):
499+
"""Verify that errors from @define_tool handlers propagate and get recorded on spans.
500+
501+
This validates the fix where @define_tool no longer catches exceptions internally,
502+
allowing _execute_tool_call to record error.type and ERROR status on the
503+
execute_tool span — consistent with Node.js, .NET, and Go SDKs.
504+
"""
505+
from copilot import ToolInvocation, define_tool
506+
507+
exporter, reader, tp, mp = _get_exporter_and_reader(_reset_otel_globals)
508+
telemetry = _make_telemetry(tracer_provider=tp, meter_provider=mp)
509+
510+
# Use a zero-parameter handler signature to avoid the Pydantic issue triggered by `from __future__ import annotations`
511+
@define_tool(description="A tool that always fails")
512+
def failing_tool() -> str:
513+
raise RuntimeError("deliberate failure")
514+
515+
# Start an execute_tool span (as _execute_tool_call would)
516+
span = telemetry.start_execute_tool_span(
517+
tool_name="failing_tool",
518+
tool_call_id="tc-fail",
519+
description="A tool that always fails",
520+
arguments={},
521+
)
522+
523+
# Simulate _execute_tool_call: invoke the handler, catch the error, record it
524+
invocation: ToolInvocation = {
525+
"session_id": "s1",
526+
"tool_call_id": "tc-fail",
527+
"tool_name": "failing_tool",
528+
"arguments": {},
529+
}
530+
operation_error = None
531+
try:
532+
import asyncio
533+
534+
loop = asyncio.new_event_loop()
535+
loop.run_until_complete(failing_tool.handler(invocation))
536+
loop.close()
537+
except Exception as exc:
538+
operation_error = exc
539+
telemetry.record_error(span, exc)
540+
541+
span.end()
542+
543+
# The exception MUST have propagated (not swallowed by @define_tool)
544+
assert operation_error is not None, "@define_tool must not catch handler exceptions"
545+
assert isinstance(operation_error, RuntimeError)
546+
547+
# The span MUST have ERROR status and error.type
548+
s = exporter.get_finished_spans()[0]
549+
assert s.status.status_code == trace.StatusCode.ERROR
550+
assert s.attributes[ATTR_ERROR_TYPE] == "RuntimeError"
551+
552+
# Operation duration metric should include error.type
553+
telemetry.record_operation_duration(
554+
0.1, None, None, "github", None, None, operation_error, OP_EXECUTE_TOOL
555+
)
556+
dps = _get_metric_data_points(reader, METRIC_OPERATION_DURATION)
557+
assert len(dps) > 0
558+
error_dp = [dp for dp in dps if dp.attributes.get(ATTR_ERROR_TYPE) == "RuntimeError"]
559+
assert len(error_dp) > 0, "duration metric includes error.type for failed tool"
560+
498561

499562
# ---------------------------------------------------------------------------
500563
# Tests: Tool result recording

0 commit comments

Comments
 (0)