Skip to content

Commit caac425

Browse files
Morabbin and Copilot committed
Address code review: tighter guards, default ResultType, session cleanup
- Node: validate resultType against allowed values in isToolResultObject
- Go: default empty ResultType to 'success' (or 'failure' when error set)
- Python: use _from_exception attribute instead of sentinel string match
- Python/Go: disconnect sessions in e2e tests to avoid leaking state

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 9700109 commit caac425

6 files changed

Lines changed: 90 additions & 48 deletions

File tree

go/internal/e2e/tool_results_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ func TestToolResults(t *testing.T) {
5353
if !strings.Contains(strings.ToLower(content), "sunny") && !strings.Contains(content, "72") {
5454
t.Errorf("Expected answer to mention sunny or 72, got %q", content)
5555
}
56+
57+
if err := session.Disconnect(); err != nil {
58+
t.Errorf("Failed to disconnect session: %v", err)
59+
}
5660
})
5761

5862
t.Run("should handle tool result with failure resulttype", func(t *testing.T) {
@@ -97,6 +101,10 @@ func TestToolResults(t *testing.T) {
97101
if !strings.Contains(strings.ToLower(content), "service is down") {
98102
t.Errorf("Expected 'service is down', got %q", content)
99103
}
104+
105+
if err := session.Disconnect(); err != nil {
106+
t.Errorf("Failed to disconnect session: %v", err)
107+
}
100108
})
101109

102110
t.Run("should preserve tooltelemetry and not stringify structured results for llm", func(t *testing.T) {
@@ -167,5 +175,9 @@ func TestToolResults(t *testing.T) {
167175
if strings.Contains(toolResults[0].Content, "resultType") {
168176
t.Error("Tool result content should not contain 'resultType'")
169177
}
178+
179+
if err := session.Disconnect(); err != nil {
180+
t.Errorf("Failed to disconnect session: %v", err)
181+
}
170182
})
171183
}

go/session.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -625,15 +625,23 @@ func (s *Session) executeToolAndRespond(requestID, toolName, toolCallID string,
625625
textResultForLLM = fmt.Sprintf("%v", result)
626626
}
627627

628+
// Default ResultType to "success" when unset, or "failure" when there's an error.
629+
effectiveResultType := result.ResultType
630+
if effectiveResultType == "" {
631+
if result.Error != "" {
632+
effectiveResultType = "failure"
633+
} else {
634+
effectiveResultType = "success"
635+
}
636+
}
637+
628638
rpcResult := rpc.ResultUnion{
629639
ResultResult: &rpc.ResultResult{
630640
TextResultForLlm: textResultForLLM,
631641
ToolTelemetry: result.ToolTelemetry,
642+
ResultType: &effectiveResultType,
632643
},
633644
}
634-
if result.ResultType != "" {
635-
rpcResult.ResultResult.ResultType = &result.ResultType
636-
}
637645
if result.Error != "" {
638646
rpcResult.ResultResult.Error = &result.Error
639647
}

nodejs/src/session.ts

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,15 +1050,27 @@ export class CopilotSession {
10501050

10511051
/**
10521052
* Type guard that checks whether a value is a {@link ToolResultObject}.
1053-
* A valid object must have a string `textResultForLlm` and a string `resultType`.
1053+
* A valid object must have a string `textResultForLlm` and a recognized `resultType`.
10541054
*/
10551055
function isToolResultObject(value: unknown): value is ToolResultObject {
1056-
return (
1057-
typeof value === "object" &&
1058-
value !== null &&
1059-
"textResultForLlm" in value &&
1060-
typeof (value as ToolResultObject).textResultForLlm === "string" &&
1061-
"resultType" in value &&
1062-
typeof (value as ToolResultObject).resultType === "string"
1063-
);
1056+
if (typeof value !== "object" || value === null) {
1057+
return false;
1058+
}
1059+
1060+
if (!("textResultForLlm" in value) || typeof (value as ToolResultObject).textResultForLlm !== "string") {
1061+
return false;
1062+
}
1063+
1064+
if (!("resultType" in value) || typeof (value as ToolResultObject).resultType !== "string") {
1065+
return false;
1066+
}
1067+
1068+
const allowedResultTypes: Array<ToolResultObject["resultType"]> = [
1069+
"success",
1070+
"failure",
1071+
"rejected",
1072+
"denied",
1073+
];
1074+
1075+
return allowedResultTypes.includes((value as ToolResultObject).resultType);
10641076
}

python/copilot/session.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
SessionEventType,
3939
session_event_from_dict,
4040
)
41-
from .tools import TOOL_EXCEPTION_TEXT, Tool, ToolHandler, ToolInvocation, ToolResult
41+
from .tools import Tool, ToolHandler, ToolInvocation, ToolResult
4242

4343
# Re-export SessionEvent under an alias used internally
4444
SessionEventTypeAlias = SessionEvent
@@ -948,11 +948,7 @@ async def _execute_tool_and_respond(
948948
# sent via the top-level error param so the CLI formats them with its
949949
# standard "Failed to execute..." message. Deliberate user-returned
950950
# failures send the full structured result to preserve metadata.
951-
if (
952-
tool_result.result_type == "failure"
953-
and tool_result.error
954-
and tool_result.text_result_for_llm == TOOL_EXCEPTION_TEXT
955-
):
951+
if getattr(tool_result, "_from_exception", False):
956952
await self.rpc.tools.handle_pending_tool_call(
957953
SessionToolsHandlePendingToolCallParams(
958954
request_id=request_id,

python/copilot/tools.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@
1717

1818
ToolResultType = Literal["success", "failure", "rejected", "denied"]
1919

20-
# Constant used by define_tool's exception handler so that
21-
# _execute_tool_and_respond can detect exception-originated failures
22-
# and send them via the top-level error param (matching CLI formatting).
23-
TOOL_EXCEPTION_TEXT = "Invoking this tool produced an error. Detailed information is not available."
24-
2520

2621
@dataclass
2722
class ToolBinaryResult:
@@ -199,12 +194,20 @@ async def wrapped_handler(invocation: ToolInvocation) -> ToolResult:
199194
except Exception as exc:
200195
# Don't expose detailed error information to the LLM for security reasons.
201196
# The actual error is stored in the 'error' field for debugging.
202-
return ToolResult(
203-
text_result_for_llm=TOOL_EXCEPTION_TEXT,
197+
tr = ToolResult(
198+
text_result_for_llm=(
199+
"Invoking this tool produced an error. "
200+
"Detailed information is not available."
201+
),
204202
result_type="failure",
205203
error=str(exc),
206204
tool_telemetry={},
207205
)
206+
# Mark as exception-originated so _execute_tool_and_respond
207+
# sends it via the top-level error param (matching CLI formatting)
208+
# rather than as a structured result.
209+
tr._from_exception = True # type: ignore[attr-defined]
210+
return tr
208211

209212
return Tool(
210213
name=tool_name,

python/e2e/test_tool_results.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@ def get_weather(params: WeatherParams, invocation: ToolInvocation) -> ToolResult
3030
on_permission_request=PermissionHandler.approve_all, tools=[get_weather]
3131
)
3232

33-
await session.send("What's the weather in Paris?")
34-
assistant_message = await get_final_assistant_message(session)
35-
assert (
36-
"sunny" in assistant_message.data.content.lower()
37-
or "72" in assistant_message.data.content
38-
)
33+
try:
34+
await session.send("What's the weather in Paris?")
35+
assistant_message = await get_final_assistant_message(session)
36+
assert (
37+
"sunny" in assistant_message.data.content.lower()
38+
or "72" in assistant_message.data.content
39+
)
40+
finally:
41+
await session.disconnect()
3942

4043
async def test_should_handle_tool_result_with_failure_resulttype(self, ctx: E2ETestContext):
4144
@define_tool("check_status", description="Checks the status of a service")
@@ -50,12 +53,15 @@ def check_status(invocation: ToolInvocation) -> ToolResult:
5053
on_permission_request=PermissionHandler.approve_all, tools=[check_status]
5154
)
5255

53-
answer = await session.send_and_wait(
54-
"Check the status of the service using check_status."
55-
" If it fails, say 'service is down'."
56-
)
57-
assert answer is not None
58-
assert "service is down" in answer.data.content.lower()
56+
try:
57+
answer = await session.send_and_wait(
58+
"Check the status of the service using check_status."
59+
" If it fails, say 'service is down'."
60+
)
61+
assert answer is not None
62+
assert "service is down" in answer.data.content.lower()
63+
finally:
64+
await session.disconnect()
5965

6066
async def test_should_preserve_tooltelemetry_and_not_stringify_structured_results_for_llm(
6167
self, ctx: E2ETestContext
@@ -78,14 +84,19 @@ def analyze_code(params: AnalyzeParams, invocation: ToolInvocation) -> ToolResul
7884
on_permission_request=PermissionHandler.approve_all, tools=[analyze_code]
7985
)
8086

81-
await session.send("Analyze the file main.ts for issues.")
82-
assistant_message = await get_final_assistant_message(session)
83-
assert "no issues" in assistant_message.data.content.lower()
84-
85-
# Verify the LLM received just textResultForLlm, not stringified JSON
86-
traffic = await ctx.get_exchanges()
87-
last_conversation = traffic[-1]
88-
tool_results = [m for m in last_conversation["request"]["messages"] if m["role"] == "tool"]
89-
assert len(tool_results) == 1
90-
assert "toolTelemetry" not in tool_results[0]["content"]
91-
assert "resultType" not in tool_results[0]["content"]
87+
try:
88+
await session.send("Analyze the file main.ts for issues.")
89+
assistant_message = await get_final_assistant_message(session)
90+
assert "no issues" in assistant_message.data.content.lower()
91+
92+
# Verify the LLM received just textResultForLlm, not stringified JSON
93+
traffic = await ctx.get_exchanges()
94+
last_conversation = traffic[-1]
95+
tool_results = [
96+
m for m in last_conversation["request"]["messages"] if m["role"] == "tool"
97+
]
98+
assert len(tool_results) == 1
99+
assert "toolTelemetry" not in tool_results[0]["content"]
100+
assert "resultType" not in tool_results[0]["content"]
101+
finally:
102+
await session.disconnect()

0 commit comments

Comments
 (0)