Skip to content

Commit 73fe9cc

Browse files
authored
feat: add CancellationToken for graceful agent execution cancellation (#1772)
1 parent: 98636ae · commit: 73fe9cc

9 files changed

Lines changed: 554 additions & 4 deletions

File tree

src/strands/agent/agent.py

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -240,6 +240,9 @@ def __init__(
240240
self.record_direct_tool_call = record_direct_tool_call
241241
self.load_tools_from_directory = load_tools_from_directory
242242

243+
# Create internal cancel signal for graceful cancellation using threading.Event
244+
self._cancel_signal = threading.Event()
245+
243246
self.tool_registry = ToolRegistry()
244247

245248
# Process tool list if provided
@@ -327,6 +330,37 @@ def __init__(
327330

328331
self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self))
329332

333+
def cancel(self) -> None:
334+
"""Cancel the currently running agent invocation.
335+
336+
This method is thread-safe and can be called from any context
337+
(e.g., another thread, web request handler, background task).
338+
339+
The agent will stop gracefully at the next checkpoint:
340+
- During model response streaming
341+
- Before tool execution
342+
343+
The agent will return a result with stop_reason="cancelled".
344+
345+
Example:
346+
```python
347+
agent = Agent(model=model)
348+
349+
# Start agent in background
350+
task = asyncio.create_task(agent.invoke_async("Hello"))
351+
352+
# Cancel from another context
353+
agent.cancel()
354+
355+
result = await task
356+
assert result.stop_reason == "cancelled"
357+
```
358+
359+
Note:
360+
Multiple calls to cancel() are safe and idempotent.
361+
"""
362+
self._cancel_signal.set()
363+
330364
@property
331365
def system_prompt(self) -> str | None:
332366
"""Get the system prompt as a string for backwards compatibility.
@@ -756,6 +790,9 @@ async def stream_async(
756790
raise
757791

758792
finally:
793+
# Clear cancel signal to allow agent reuse after cancellation
794+
self._cancel_signal.clear()
795+
759796
if self._invocation_lock.locked():
760797
self._invocation_lock.release()
761798

src/strands/event_loop/event_loop.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -336,6 +336,7 @@ async def _handle_model_execution(
336336
system_prompt_content=agent._system_prompt_content,
337337
tool_choice=structured_output_context.tool_choice,
338338
invocation_state=invocation_state,
339+
cancel_signal=agent._cancel_signal,
339340
):
340341
yield event
341342

@@ -465,6 +466,47 @@ async def _handle_tool_execution(
465466
tool_uses = [tool_use for tool_use in tool_uses if tool_use["toolUseId"] not in tool_use_ids]
466467

467468
interrupts = []
469+
470+
# Check for cancellation before tool execution
471+
# Add tool_result for each tool_use to maintain valid conversation state
472+
if agent._cancel_signal.is_set():
473+
logger.debug("tool_count=<%d> | cancellation detected before tool execution", len(tool_uses))
474+
475+
# Create cancellation tool_result for each tool_use to avoid invalid message state
476+
# (tool_use without tool_result would be rejected on next invocation)
477+
for tool_use in tool_uses:
478+
cancel_result: ToolResult = {
479+
"toolUseId": str(tool_use.get("toolUseId")),
480+
"status": "error",
481+
"content": [{"text": "Tool execution cancelled"}],
482+
}
483+
tool_results.append(cancel_result)
484+
485+
# Add tool results message to conversation if any tools were cancelled
486+
cancelled_tool_result_message: Message | None = None
487+
if tool_results:
488+
_cancelled_msg: Message = {
489+
"role": "user",
490+
"content": [{"toolResult": result} for result in tool_results],
491+
}
492+
cancelled_tool_result_message = _cancelled_msg
493+
agent.messages.append(_cancelled_msg)
494+
await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=_cancelled_msg))
495+
yield ToolResultMessageEvent(message=_cancelled_msg)
496+
497+
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
498+
yield EventLoopStopEvent(
499+
"cancelled",
500+
message,
501+
agent.event_loop_metrics,
502+
invocation_state["request_state"],
503+
)
504+
if cycle_span:
505+
tracer.end_event_loop_cycle_span(
506+
span=cycle_span, message=message, tool_result_message=cancelled_tool_result_message
507+
)
508+
return
509+
468510
tool_events = agent.tool_executor._execute(
469511
agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context
470512
)

src/strands/event_loop/streaming.py

Lines changed: 21 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
import json
44
import logging
5+
import threading
56
import time
67
import warnings
78
from collections.abc import AsyncGenerator, AsyncIterable
@@ -368,13 +369,16 @@ def extract_usage_metrics(event: MetadataEvent, time_to_first_byte_ms: int | Non
368369

369370

370371
async def process_stream(
371-
chunks: AsyncIterable[StreamEvent], start_time: float | None = None
372+
chunks: AsyncIterable[StreamEvent],
373+
start_time: float | None = None,
374+
cancel_signal: threading.Event | None = None,
372375
) -> AsyncGenerator[TypedEvent, None]:
373376
"""Processes the response stream from the API, constructing the final message and extracting usage metrics.
374377
375378
Args:
376379
chunks: The chunks of the response stream from the model.
377380
start_time: Time when the model request is initiated
381+
cancel_signal: Optional threading.Event to check for cancellation during streaming.
378382
379383
Yields:
380384
The reason for stopping, the constructed message, and the usage metrics.
@@ -395,6 +399,19 @@ async def process_stream(
395399
metrics: Metrics = Metrics(latencyMs=0, timeToFirstByteMs=0)
396400

397401
async for chunk in chunks:
402+
# Check for cancellation during stream processing
403+
if cancel_signal and cancel_signal.is_set():
404+
logger.debug("cancellation detected during stream processing")
405+
# Return cancelled stop reason with cancellation message
406+
# The incomplete message in state["message"] is discarded and never added to agent.messages
407+
yield ModelStopReason(
408+
stop_reason="cancelled",
409+
message={"role": "assistant", "content": [{"text": "Cancelled by user"}]},
410+
usage=usage,
411+
metrics=metrics,
412+
)
413+
return
414+
398415
# Track first byte time when we get first content
399416
if first_byte_time is None and ("contentBlockDelta" in chunk or "contentBlockStart" in chunk):
400417
first_byte_time = time.time()
@@ -431,6 +448,7 @@ async def stream_messages(
431448
tool_choice: Any | None = None,
432449
system_prompt_content: list[SystemContentBlock] | None = None,
433450
invocation_state: dict[str, Any] | None = None,
451+
cancel_signal: threading.Event | None = None,
434452
**kwargs: Any,
435453
) -> AsyncGenerator[TypedEvent, None]:
436454
"""Streams messages to the model and processes the response.
@@ -444,6 +462,7 @@ async def stream_messages(
444462
system_prompt_content: The authoritative system prompt content blocks that always contains the
445463
system prompt data.
446464
invocation_state: Caller-provided state/context that was passed to the agent when it was invoked.
465+
cancel_signal: Optional threading.Event to check for cancellation during streaming.
447466
**kwargs: Additional keyword arguments for future extensibility.
448467
449468
Yields:
@@ -463,5 +482,5 @@ async def stream_messages(
463482
invocation_state=invocation_state,
464483
)
465484

466-
async for event in process_stream(chunks, start_time):
485+
async for event in process_stream(chunks, start_time, cancel_signal):
467486
yield event

src/strands/session/repository_session_manager.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -124,8 +124,8 @@ def sync_agent(self, agent: "Agent", **kwargs: Any) -> None:
124124
else:
125125
state_changed = current_state_version != last_synced.get("state_version")
126126
internal_state_changed = current_interrupt_state_version != last_synced.get("interrupt_state_version")
127-
conversation_manager_state_changed = (
128-
current_conversation_manager_state != last_synced.get("conversation_manager_state")
127+
conversation_manager_state_changed = current_conversation_manager_state != last_synced.get(
128+
"conversation_manager_state"
129129
)
130130

131131
if not state_changed and not internal_state_changed and not conversation_manager_state_changed:

src/strands/types/event_loop.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ class Metrics(TypedDict, total=False):
3737

3838

3939
StopReason = Literal[
40+
"cancelled",
4041
"content_filtered",
4142
"end_turn",
4243
"guardrail_intervened",
@@ -47,6 +48,7 @@ class Metrics(TypedDict, total=False):
4748
]
4849
"""Reason for the model ending its response generation.
4950
51+
- "cancelled": Agent execution was cancelled via agent.cancel()
5052
- "content_filtered": Content was filtered due to policy violation
5153
- "end_turn": Normal completion of the response
5254
- "guardrail_intervened": Guardrail system intervened

0 commit comments

Comments (0)