Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions src/claude_agent_sdk/_internal/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ListToolsRequest,
)

from .._errors import ProcessError
from ..types import (
PermissionMode,
PermissionResultAllow,
Expand Down Expand Up @@ -128,6 +129,11 @@ def __init__(

# Track first result for proper stream closure with SDK MCP servers
self._first_result_event = anyio.Event()
# Set to the result's error text when the most recent message is a
# result with is_error=True. Used to replace the generic "exit code 1"
# ProcessError with the structured error the CLI already reported.
# Mirrors the TypeScript SDK's `lastErrorResultText` (Query.ts).
self._last_error_result_text: str | None = None

# SessionStore mirroring (set via set_transcript_mirror_batcher)
self._transcript_mirror_batcher: TranscriptMirrorBatcher | None = None
Expand Down Expand Up @@ -294,6 +300,22 @@ async def _read_messages(self) -> None:
if self._transcript_mirror_batcher is not None:
await self._transcript_mirror_batcher.flush()
self._first_result_event.set()
if message.get("is_error"):
errors = message.get("errors") or []
self._last_error_result_text = "; ".join(errors) or str(
message.get("subtype", "unknown error")
)
else:
self._last_error_result_text = None
elif not (
msg_type == "system"
and message.get("subtype") == "session_state_changed"
):
# Anything other than the post-turn session_state_changed
# marker means the conversation moved on; a ProcessError
# now is a fresh crash, not the expected exit from a prior
# error result. Mirrors the TypeScript SDK's reset logic.
self._last_error_result_text = None

# Regular SDK messages go to the stream
await self._message_send.send(message)
Expand All @@ -303,14 +325,31 @@ async def _read_messages(self) -> None:
logger.debug("Read task cancelled")
raise # Re-raise to properly handle cancellation
except Exception as e:
logger.error(f"Fatal error in message reader: {e}")
# Signal all pending control requests so they fail fast instead of timing out
for request_id, event in list(self.pending_control_responses.items()):
if request_id not in self.pending_control_results:
self.pending_control_results[request_id] = e
event.set()
# When the CLI emits a result with is_error=True (e.g.
# error_max_turns, error_during_execution) it then exits non-zero
# on purpose, for shell-script consumers. The trailing ProcessError
# carries no information beyond "exit code 1" — replace it with the
# structured error the CLI already reported so the exception is
# actionable. Mirrors the TypeScript SDK (Query.ts readMessages).
if isinstance(e, ProcessError) and self._last_error_result_text is not None:
error_text = (
f"Claude Code returned an error result: "
f"{self._last_error_result_text}"
)
logger.debug(
"Replacing ProcessError (exit code %s) with result error text",
e.exit_code,
)
else:
error_text = str(e)
logger.error(f"Fatal error in message reader: {e}")
# Put error in stream so iterators can handle it
await self._message_send.send({"type": "error", "error": str(e)})
await self._message_send.send({"type": "error", "error": error_text})
finally:
# Flush any remaining transcript mirror entries before closing so
# an early stdout EOF or transport error doesn't drop entries
Expand Down
305 changes: 305 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
query,
tool,
)
from claude_agent_sdk._errors import ProcessError
from claude_agent_sdk._internal.query import Query
from claude_agent_sdk.types import HookMatcher

Expand Down Expand Up @@ -949,3 +950,307 @@ async def _test():
assert "fast_1" not in q._inflight_requests

asyncio.run(_test())


class TestProcessExitAfterErrorResult:
"""Regression tests for #913: when the CLI emits a result message with
is_error=True (e.g. subtype=error_max_turns) and then exits non-zero,
the trailing ProcessError carries no information beyond "exit code 1".
Replace it with the structured error text the CLI already reported so
the exception is actionable. Mirrors the TypeScript SDK (Query.ts)."""

def _make_transport_then_raise(self, messages, exc):
mock_transport = AsyncMock()

async def mock_receive():
for msg in messages:
yield msg
raise exc

mock_transport.read_messages = mock_receive
mock_transport.connect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.write = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)
return mock_transport

def _error_result(self, subtype="error_max_turns", errors=None, **overrides):
msg = {
"type": "result",
"subtype": subtype,
"is_error": True,
"num_turns": 1,
"session_id": "s",
"duration_ms": 1,
"duration_api_ms": 1,
"total_cost_usd": 0.0,
}
if errors is not None:
msg["errors"] = errors
msg.update(overrides)
return msg

def test_process_error_after_error_result_uses_result_error_text(self):
async def _test():
transport = self._make_transport_then_raise(
messages=[
self._error_result(
subtype="error_max_turns",
errors=["Reached maximum number of turns (60)"],
num_turns=60,
)
],
exc=ProcessError(
"Command failed with exit code 1", exit_code=1, stderr=""
),
)
q = Query(transport=transport, is_streaming_mode=True)
await q.start()

received = []
with pytest.raises(
Exception,
match=r"Claude Code returned an error result: "
r"Reached maximum number of turns \(60\)",
):
async for msg in q.receive_messages():
received.append(msg)
await q.close()

assert len(received) == 1
assert received[0]["subtype"] == "error_max_turns"

anyio.run(_test)

def test_process_error_after_error_result_falls_back_to_subtype(self):
"""When the result has no errors[] (older CLI / minimal payload), the
improved message falls back to the subtype so it's still actionable."""

async def _test():
transport = self._make_transport_then_raise(
messages=[self._error_result(subtype="error_during_execution")],
exc=ProcessError(
"Command failed with exit code 1", exit_code=1, stderr=""
),
)
q = Query(transport=transport, is_streaming_mode=True)
await q.start()

with pytest.raises(
Exception,
match=r"Claude Code returned an error result: error_during_execution",
):
async for _ in q.receive_messages():
pass
await q.close()

anyio.run(_test)

def test_process_error_after_error_result_joins_multiple_errors(self):
async def _test():
transport = self._make_transport_then_raise(
messages=[
self._error_result(
subtype="error_during_execution",
errors=["tool timed out", "ENOENT: missing file"],
)
],
exc=ProcessError(
"Command failed with exit code 1", exit_code=1, stderr=""
),
)
q = Query(transport=transport, is_streaming_mode=True)
await q.start()

with pytest.raises(
Exception,
match=r"tool timed out; ENOENT: missing file",
):
async for _ in q.receive_messages():
pass
await q.close()

anyio.run(_test)

def test_process_error_without_result_keeps_original_message(self):
async def _test():
transport = self._make_transport_then_raise(
messages=[],
exc=ProcessError(
"Command failed with exit code 1", exit_code=1, stderr=""
),
)
q = Query(transport=transport, is_streaming_mode=True)
await q.start()

with pytest.raises(Exception, match="Command failed"):
async for _ in q.receive_messages():
pass
await q.close()

anyio.run(_test)

def test_process_error_after_success_result_keeps_original_message(self):
async def _test():
transport = self._make_transport_then_raise(
messages=[
{
"type": "result",
"subtype": "success",
"is_error": False,
"num_turns": 1,
"session_id": "s",
"duration_ms": 1,
"duration_api_ms": 1,
"total_cost_usd": 0.0,
}
],
exc=ProcessError(
"Command failed with exit code 1", exit_code=1, stderr=""
),
)
q = Query(transport=transport, is_streaming_mode=True)
await q.start()

received = []
with pytest.raises(Exception, match="Command failed"):
async for msg in q.receive_messages():
received.append(msg)
await q.close()

assert len(received) == 1
assert received[0]["subtype"] == "success"

anyio.run(_test)

def test_process_error_after_error_then_success_result_keeps_original(self):
"""Tracks the *most recent* result, not a sticky latch."""

async def _test():
transport = self._make_transport_then_raise(
messages=[
self._error_result(subtype="error_during_execution"),
{
"type": "result",
"subtype": "success",
"is_error": False,
"num_turns": 2,
"session_id": "s",
"duration_ms": 1,
"duration_api_ms": 1,
"total_cost_usd": 0.0,
},
],
exc=ProcessError(
"Command failed with exit code 1", exit_code=1, stderr=""
),
)
q = Query(transport=transport, is_streaming_mode=True)
await q.start()

received = []
with pytest.raises(Exception, match="Command failed"):
async for msg in q.receive_messages():
received.append(msg)
await q.close()

assert len(received) == 2

anyio.run(_test)

def test_session_state_changed_after_error_result_preserves_replacement(self):
"""The CLI emits a post-turn `system: session_state_changed(idle)`
marker after the result and before exit. It must not reset the
tracking flag — the conversation hasn't moved on."""

async def _test():
transport = self._make_transport_then_raise(
messages=[
self._error_result(
subtype="error_max_turns",
errors=["Reached maximum number of turns (10)"],
),
{
"type": "system",
"subtype": "session_state_changed",
"state": "idle",
"session_id": "s",
},
],
exc=ProcessError(
"Command failed with exit code 1", exit_code=1, stderr=""
),
)
q = Query(transport=transport, is_streaming_mode=True)
await q.start()

with pytest.raises(
Exception, match=r"Claude Code returned an error result"
):
async for _ in q.receive_messages():
pass
await q.close()

anyio.run(_test)

def test_new_turn_after_error_result_keeps_original_message(self):
"""A new user turn invalidates the 'expecting imminent exit' state from
a prior turn's error result; a crash mid-new-turn must surface as-is."""

async def _test():
transport = self._make_transport_then_raise(
messages=[
self._error_result(subtype="error_during_execution"),
{
"type": "user",
"message": {"role": "user", "content": "next turn"},
"session_id": "s",
},
],
exc=ProcessError(
"Command failed with exit code 1", exit_code=1, stderr=""
),
)
q = Query(transport=transport, is_streaming_mode=True)
await q.start()

received = []
with pytest.raises(Exception, match="Command failed"):
async for msg in q.receive_messages():
received.append(msg)
await q.close()

assert len(received) == 2

anyio.run(_test)

def test_pending_control_requests_fail_fast_on_replaced_error(self):
"""In-flight control requests must still fail fast (process is dead;
no control_response will ever arrive) regardless of message replacement."""

async def _test():
transport = self._make_transport_then_raise(
messages=[self._error_result(subtype="error_max_turns")],
exc=ProcessError(
"Command failed with exit code 1", exit_code=1, stderr=""
),
)
q = Query(transport=transport, is_streaming_mode=True)

# Register a pending control request before the read loop runs.
event = anyio.Event()
q.pending_control_responses["req_1"] = event

await q.start()
with pytest.raises(
Exception, match=r"Claude Code returned an error result"
):
async for _ in q.receive_messages():
pass
await q.close()

assert event.is_set()
assert isinstance(q.pending_control_results["req_1"], ProcessError)

anyio.run(_test)
Loading