Skip to content
66 changes: 65 additions & 1 deletion src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .._async import run_async
from ..event_loop._retry import ModelRetryStrategy
from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle
from ..experimental.checkpoint import Checkpoint
from ..tools._tool_helpers import generate_missing_tool_result_content
from ..types._snapshot import (
SNAPSHOT_SCHEMA_VERSION,
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
tool_executor: ToolExecutor | None = None,
retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY,
concurrent_invocation_mode: ConcurrentInvocationMode = ConcurrentInvocationMode.THROW,
checkpointing: bool = False,
Comment thread
JackYPCOnline marked this conversation as resolved.
):
"""Initialize the Agent with the specified configuration.

Expand Down Expand Up @@ -214,6 +216,11 @@ def __init__(
Set to "unsafe_reentrant" to skip lock acquisition entirely, allowing concurrent invocations.
Warning: "unsafe_reentrant" makes no guarantees about resulting behavior and is provided
only for advanced use cases where the caller understands the risks.
checkpointing: When True, the event loop pauses at cycle boundaries (after model call,
after all tools execute) and returns an AgentResult with stop_reason="checkpoint"
and a populated ``checkpoint`` field. Persist the checkpoint and resume by passing
``[{"checkpointResume": {"checkpoint": checkpoint.to_dict()}}]`` as the next prompt.
Defaults to False. See :mod:`strands.experimental.checkpoint` for usage and limitations.

Raises:
ValueError: If agent id contains path separators.
Expand Down Expand Up @@ -304,6 +311,10 @@ def __init__(

self._interrupt_state = _InterruptState()

# Checkpointing: when True, event loop pauses at cycle boundaries
self._checkpointing: bool = checkpointing
self._checkpoint_resume_context: Checkpoint | None = None

# Runtime state for model providers (e.g., server-side response ids)
self._model_state: dict[str, Any] = {}

Expand Down Expand Up @@ -374,12 +385,18 @@ def cancel(self) -> None:
This method is thread-safe and can be called from any context
(e.g., another thread, web request handler, background task).

The agent will stop gracefully at the next checkpoint:
The agent will stop gracefully at the next cancellation-safe point:
- During model response streaming
- Before tool execution

The agent will return a result with stop_reason="cancelled".

Note:
"Cancellation-safe point" is distinct from
:class:`~strands.experimental.checkpoint.Checkpoint` boundaries.
Cancel takes precedence: a cancel signal at either checkpoint boundary
surfaces as ``stop_reason="cancelled"``, not ``"checkpoint"``.

Example:
```python
agent = Agent(model=model)
Expand Down Expand Up @@ -1006,10 +1023,57 @@ async def _execute_event_loop_cycle(
if structured_output_context:
structured_output_context.cleanup(self.tool_registry)

def _try_consume_checkpoint_resume(self, prompt: list[Any]) -> bool:
    """Detect, validate, and consume a ``checkpointResume`` prompt block.

    Returns True if the prompt was a resume block (state restored, caller
    should skip normal message conversion). Returns False if no resume
    block is present. Raises on malformed or misconfigured input.

    Follows interrupt-resume error conventions: TypeError for shape issues,
    KeyError for missing keys, ValueError for misconfig. A schema mismatch
    in the checkpoint payload raises ``CheckpointException``.
    """
    # Fast exit: no checkpointResume block anywhere in the prompt.
    if not any(isinstance(block, dict) and "checkpointResume" in block for block in prompt):
        return False

    # Resume was requested but the agent was never configured for it.
    if not self._checkpointing:
        raise ValueError(
            "Received checkpointResume block but agent was created with "
            "checkpointing=False. Pass checkpointing=True when constructing "
            "the Agent to enable durable execution."
        )

    # A resume prompt must carry nothing besides checkpointResume keys.
    invalid_types: list[str] = []
    for block in prompt:
        if isinstance(block, dict):
            invalid_types.extend(key for key in block if key != "checkpointResume")
    if invalid_types:
        raise TypeError(
            f"content_types=<{invalid_types}> | checkpointResume cannot be mixed with other content types"
        )

    # ... and exactly one block, so the resume target is unambiguous.
    if len(prompt) != 1:
        raise TypeError(f"block_count=<{len(prompt)}> | only one checkpointResume block permitted per prompt")

    resume_block = prompt[0].get("checkpointResume", {})
    if not (isinstance(resume_block, dict) and "checkpoint" in resume_block):
        raise KeyError("checkpoint | missing required key in checkpointResume block")

    # Restore agent state from the checkpoint's embedded snapshot, then stash
    # the checkpoint so the event loop can consume it (one-shot) on next cycle.
    checkpoint = Checkpoint.from_dict(resume_block["checkpoint"])
    self.load_snapshot(Snapshot.from_dict(checkpoint.snapshot))
    self._checkpoint_resume_context = checkpoint
    return True

async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
if self._interrupt_state.activated:
return []

# Resume detection — must run before existing shape handling so checkpointResume
# blocks aren't misinterpreted as content blocks.
if isinstance(prompt, list) and prompt and self._try_consume_checkpoint_resume(prompt):
return []

messages: Messages | None = None
if prompt is not None:
# Check if the latest message is toolUse
Expand Down
18 changes: 16 additions & 2 deletions src/strands/agent/agent_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from pydantic import BaseModel

from ..experimental.checkpoint import Checkpoint
from ..interrupt import Interrupt
from ..telemetry.metrics import EventLoopMetrics
from ..types.content import Message
Expand All @@ -26,6 +27,9 @@ class AgentResult:
state: Additional state information from the event loop.
interrupts: List of interrupts if raised by user.
structured_output: Parsed structured output when structured_output_model was specified.
checkpoint: Checkpoint captured when the agent paused for durable execution.
Populated only when stop_reason == "checkpoint". See
strands.experimental.checkpoint for usage.
"""

stop_reason: StopReason
Expand All @@ -34,6 +38,7 @@ class AgentResult:
state: Any
interrupts: Sequence[Interrupt] | None = None
structured_output: BaseModel | None = None
checkpoint: Checkpoint | None = None

@property
def context_size(self) -> int | None:
Expand Down Expand Up @@ -94,15 +99,23 @@ def from_dict(cls, data: dict[str, Any]) -> "AgentResult":
Returns:
AgentResult instance
Raises:
TypeError: If the data format is invalid@
TypeError: If the data format is invalid
"""
if data.get("type") != "agent_result":
raise TypeError(f"AgentResult.from_dict: unexpected type {data.get('type')!r}")

message = cast(Message, data.get("message"))
stop_reason = cast(StopReason, data.get("stop_reason"))
checkpoint_data = data.get("checkpoint")
checkpoint = Checkpoint.from_dict(checkpoint_data) if checkpoint_data else None

return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={})
return cls(
message=message,
stop_reason=stop_reason,
metrics=EventLoopMetrics(),
state={},
checkpoint=checkpoint,
)

def to_dict(self) -> dict[str, Any]:
"""Convert this AgentResult to JSON-serializable dictionary.
Expand All @@ -114,4 +127,5 @@ def to_dict(self) -> dict[str, Any]:
"type": "agent_result",
"message": self.message,
"stop_reason": self.stop_reason,
"checkpoint": self.checkpoint.to_dict() if self.checkpoint else None,
}
89 changes: 87 additions & 2 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from opentelemetry import trace as trace_api

from ..experimental.checkpoint import Checkpoint, CheckpointPosition
from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent
from ..telemetry.metrics import Trace
from ..telemetry.tracer import Tracer, get_tracer
Expand Down Expand Up @@ -75,6 +76,27 @@ def _has_tool_use_in_latest_message(messages: "Messages") -> bool:
return False


def _build_checkpoint_stop_event(
    agent: "Agent",
    position: "CheckpointPosition",
    cycle_index: int,
    message: Message,
    request_state: Any,
) -> EventLoopStopEvent:
    """Build a checkpoint stop event. Used at ``after_model`` and ``after_tools``."""
    # Capture the full session-preset snapshot so callers can persist it and
    # later resume the agent from this exact boundary.
    session_snapshot = agent.take_snapshot(preset="session")
    checkpoint = Checkpoint(
        position=position,
        cycle_index=cycle_index,
        snapshot=session_snapshot.to_dict(),
    )
    return EventLoopStopEvent(
        "checkpoint",
        message,
        agent.event_loop_metrics,
        request_state,
        checkpoint=checkpoint,
    )

async def _estimate_input_tokens(agent: "Agent") -> int:
"""Estimate the input token count for the next model call.

Expand Down Expand Up @@ -145,12 +167,16 @@ async def event_loop_cycle(
structured_output_context: Optional context for structured output management.

Yields:
Model and tool stream events. The last event is a tuple containing:
Model and tool stream events. The final ``EventLoopStopEvent`` payload
(``event["stop"]``) is a 7-tuple:

- StopReason: Reason the model stopped generating (e.g., "tool_use")
- StopReason: Reason the model stopped generating (e.g., "tool_use", "checkpoint")
- Message: The generated message from the model
- EventLoopMetrics: Updated metrics for the event loop
- Any: Updated request state
- Sequence[Interrupt] | None: Interrupts raised during the cycle, if any
- BaseModel | None: Structured output result, if any
- Checkpoint | None: Checkpoint captured when stop_reason == "checkpoint"

Raises:
EventLoopException: If an error occurs during execution
Expand All @@ -164,6 +190,18 @@ async def event_loop_cycle(
# Initialize state and get cycle trace
if "request_state" not in invocation_state:
invocation_state["request_state"] = {}

# Consume checkpoint resume context (one-shot).
resume_context = agent._checkpoint_resume_context
if resume_context is not None:
agent._checkpoint_resume_context = None
# after_tools completed that cycle, so the next cycle starts at +1
next_cycle = (
resume_context.cycle_index + 1 if resume_context.position == "after_tools" else resume_context.cycle_index
)
invocation_state["_checkpoint_cycle_index"] = next_cycle
invocation_state["_checkpoint_resume_position"] = resume_context.position

attributes = {"event_loop_cycle_id": str(invocation_state.get("event_loop_cycle_id"))}
cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes)
invocation_state["event_loop_cycle_trace"] = cycle_trace
Expand Down Expand Up @@ -223,6 +261,25 @@ async def event_loop_cycle(
)

if stop_reason == "tool_use":
# Checkpoint after model call, before tools. Cancel takes precedence.
if agent._checkpointing and not agent._cancel_signal.is_set():
resume_position = invocation_state.pop("_checkpoint_resume_position", None)
if resume_position == "after_model":
pass # Just resumed here — skip re-checkpoint
else:
cycle_index = invocation_state.get("_checkpoint_cycle_index", 0)
Comment thread
JackYPCOnline marked this conversation as resolved.
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
if cycle_span:
tracer.end_event_loop_cycle_span(span=cycle_span, message=message)
yield _build_checkpoint_stop_event(
agent=agent,
position="after_model",
cycle_index=cycle_index,
message=message,
request_state=invocation_state["request_state"],
)
return

# Handle tool execution
tool_events = _handle_tool_execution(
stop_reason,
Expand Down Expand Up @@ -640,6 +697,34 @@ async def _handle_tool_execution(
)
return

# Checkpoint after all tools complete, before the next model call.
# Only emitted on tool_use cycles; end_turn on the first call completes
# normally with no checkpoint. Cancel takes precedence.
if agent._checkpointing and not agent._cancel_signal.is_set():
cycle_index = invocation_state.get("_checkpoint_cycle_index", 0)
invocation_state["_checkpoint_cycle_index"] = cycle_index + 1
yield _build_checkpoint_stop_event(
agent=agent,
position="after_tools",
cycle_index=cycle_index,
message=message,
request_state=invocation_state["request_state"],
)
return

# Checkpointing-only: if cancel suppressed the after_tools checkpoint above,
# surface it as "cancelled" now rather than recursing into another model call
# that would also cancel. Non-checkpointing callers fall through to
# recurse_event_loop so the existing cancel-during-model-stream path handles it.
if agent._checkpointing and agent._cancel_signal.is_set():
yield EventLoopStopEvent(
"cancelled",
message,
agent.event_loop_metrics,
invocation_state["request_state"],
)
return

events = recurse_event_loop(
agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context
)
Expand Down
Loading
Loading