Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/claude_agent_sdk/_internal/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ListToolsRequest,
)

from .._errors import ProcessError
from ..types import (
PermissionMode,
PermissionResultAllow,
Expand Down Expand Up @@ -128,6 +129,7 @@

# Track first result for proper stream closure with SDK MCP servers
self._first_result_event = anyio.Event()
self._got_error_result = False

# SessionStore mirroring (set via set_transcript_mirror_batcher)
self._transcript_mirror_batcher: TranscriptMirrorBatcher | None = None
Expand Down Expand Up @@ -294,6 +296,7 @@
if self._transcript_mirror_batcher is not None:
await self._transcript_mirror_batcher.flush()
self._first_result_event.set()
self._got_error_result = bool(message.get("is_error"))

Check failure on line 299 in src/claude_agent_sdk/_internal/query.py

View check run for this annotation

Claude / Claude Code Review

_got_error_result not reset on new turn: mid-turn crash silently swallowed

After the review fix, `_got_error_result` tracks the *most recent* result rather than latching forever — but it is still only updated when a `result` arrives, never reset when a *new turn begins*. So in a long-lived `ClaudeSDKClient` session: turn N ends with `is_error=True` (CLI stays alive in streaming mode) → user calls `client.query(...)` for turn N+1 → CLI crashes mid-turn-N+1 *before* emitting that turn's result → the suppression check is still `True`, so the `ProcessError` is only logged at debug level and no error is put on the message stream: the mid-turn crash is silently swallowed.
Comment thread
claude[bot] marked this conversation as resolved.
Outdated

# Regular SDK messages go to the stream
await self._message_send.send(message)
Expand All @@ -303,14 +306,25 @@
logger.debug("Read task cancelled")
raise # Re-raise to properly handle cancellation
except Exception as e:
logger.error(f"Fatal error in message reader: {e}")
# Signal all pending control requests so they fail fast instead of timing out
for request_id, event in list(self.pending_control_responses.items()):
if request_id not in self.pending_control_results:
self.pending_control_results[request_id] = e
event.set()
# Put error in stream so iterators can handle it
await self._message_send.send({"type": "error", "error": str(e)})
if isinstance(e, ProcessError) and self._got_error_result:
# CLI exits non-zero after emitting an error result
# (error_max_turns, error_during_execution, ...). The consumer
# already received the structured ResultMessage; don't follow
# it with a redundant bare Exception.
logger.debug(
"CLI exited with code %s after error result; "
"treating as clean termination",
e.exit_code,
)
else:
logger.error(f"Fatal error in message reader: {e}")
# Put error in stream so iterators can handle it
await self._message_send.send({"type": "error", "error": str(e)})
finally:
# Flush any remaining transcript mirror entries before closing so
# an early stdout EOF or transport error doesn't drop entries
Expand Down
191 changes: 191 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
query,
tool,
)
from claude_agent_sdk._errors import ProcessError
from claude_agent_sdk._internal.query import Query
from claude_agent_sdk.types import HookMatcher

Expand Down Expand Up @@ -949,3 +950,193 @@ async def _test():
assert "fast_1" not in q._inflight_requests

asyncio.run(_test())


class TestProcessExitAfterErrorResult:
    """Regression coverage for #913.

    Scenario: the CLI emits a ``result`` message and then the transport
    raises ``ProcessError`` (non-zero exit). If the most recent result had
    ``is_error=True`` (e.g. ``subtype=error_max_turns``), the exit is treated
    as clean termination, because the consumer already holds the structured
    result and should not also get a bare ``Exception``. In every other case
    the error must still reach the consumer.
    """

    @staticmethod
    def _result_message(subtype, is_error, num_turns):
        """Build a ``result`` payload with the fields shared by every test."""
        return {
            "type": "result",
            "subtype": subtype,
            "is_error": is_error,
            "num_turns": num_turns,
            "session_id": "s",
            "duration_ms": 1,
            "duration_api_ms": 1,
            "total_cost_usd": 0.0,
        }

    @staticmethod
    def _exit_error():
        """Build the ``ProcessError`` used to simulate a non-zero CLI exit."""
        return ProcessError(
            "Command failed with exit code 1", exit_code=1, stderr=""
        )

    def _make_transport_then_raise(self, messages, exc):
        """Return a mock transport that yields ``messages`` and then raises ``exc``."""

        async def replay_then_fail():
            for item in messages:
                yield item
            raise exc

        fake_transport = AsyncMock()
        # read_messages must be a real async generator function, not a mock.
        fake_transport.read_messages = replay_then_fail
        for method_name in ("connect", "close", "end_input", "write"):
            setattr(fake_transport, method_name, AsyncMock())
        fake_transport.is_ready = Mock(return_value=True)
        return fake_transport

    def test_process_error_after_error_result_is_suppressed(self):
        async def scenario():
            fake_transport = self._make_transport_then_raise(
                messages=[self._result_message("error_max_turns", True, 60)],
                exc=self._exit_error(),
            )
            q = Query(transport=fake_transport, is_streaming_mode=True)
            await q.start()

            # The stream must end normally: no exception after the result.
            seen = [msg async for msg in q.receive_messages()]
            await q.close()

            assert len(seen) == 1
            assert seen[0]["subtype"] == "error_max_turns"

        anyio.run(scenario)

    def test_process_error_without_result_still_raises(self):
        async def scenario():
            fake_transport = self._make_transport_then_raise(
                messages=[],
                exc=self._exit_error(),
            )
            q = Query(transport=fake_transport, is_streaming_mode=True)
            await q.start()

            with pytest.raises(Exception, match="Command failed"):
                async for _ in q.receive_messages():
                    pass
            await q.close()

        anyio.run(scenario)

    def test_process_error_after_success_result_still_raises(self):
        async def scenario():
            fake_transport = self._make_transport_then_raise(
                messages=[self._result_message("success", False, 1)],
                exc=self._exit_error(),
            )
            q = Query(transport=fake_transport, is_streaming_mode=True)
            await q.start()

            # Collect with an explicit loop so messages delivered before the
            # exception are still available afterwards.
            seen = []
            with pytest.raises(Exception, match="Command failed"):
                async for msg in q.receive_messages():
                    seen.append(msg)
            await q.close()

            assert len(seen) == 1
            assert seen[0]["subtype"] == "success"

        anyio.run(scenario)

    def test_process_error_after_error_then_success_result_still_raises(self):
        """Only the *latest* result decides suppression; the flag is not sticky."""

        async def scenario():
            fake_transport = self._make_transport_then_raise(
                messages=[
                    self._result_message("error_during_execution", True, 1),
                    self._result_message("success", False, 2),
                ],
                exc=self._exit_error(),
            )
            q = Query(transport=fake_transport, is_streaming_mode=True)
            await q.start()

            seen = []
            with pytest.raises(Exception, match="Command failed"):
                async for msg in q.receive_messages():
                    seen.append(msg)
            await q.close()

            assert len(seen) == 2

        anyio.run(scenario)

    def test_pending_control_requests_fail_fast_on_suppressed_exit(self):
        """Suppressing the stream error must not leave control requests hanging.

        The process is dead, so no control_response will ever arrive; any
        in-flight control request still has to be failed with the error.
        """

        async def scenario():
            fake_transport = self._make_transport_then_raise(
                messages=[self._result_message("error_max_turns", True, 1)],
                exc=self._exit_error(),
            )
            q = Query(transport=fake_transport, is_streaming_mode=True)

            # Register the in-flight control request before the read loop starts.
            waiter = anyio.Event()
            q.pending_control_responses["req_1"] = waiter

            await q.start()
            async for _ in q.receive_messages():
                pass
            await q.close()

            assert waiter.is_set()
            assert isinstance(q.pending_control_results["req_1"], ProcessError)

        anyio.run(scenario)
Loading