22
33import json
44import logging
5+ import threading
56import time
67import warnings
78from collections .abc import AsyncGenerator , AsyncIterable
@@ -368,13 +369,16 @@ def extract_usage_metrics(event: MetadataEvent, time_to_first_byte_ms: int | Non
368369
369370
370371async def process_stream (
371- chunks : AsyncIterable [StreamEvent ], start_time : float | None = None
372+ chunks : AsyncIterable [StreamEvent ],
373+ start_time : float | None = None ,
374+ cancel_signal : threading .Event | None = None ,
372375) -> AsyncGenerator [TypedEvent , None ]:
373376 """Processes the response stream from the API, constructing the final message and extracting usage metrics.
374377
375378 Args:
376379 chunks: The chunks of the response stream from the model.
377380 start_time: Time when the model request is initiated
381+ cancel_signal: Optional threading.Event to check for cancellation during streaming.
378382
379383 Yields:
380384 The reason for stopping, the constructed message, and the usage metrics.
@@ -395,6 +399,19 @@ async def process_stream(
395399 metrics : Metrics = Metrics (latencyMs = 0 , timeToFirstByteMs = 0 )
396400
397401 async for chunk in chunks :
402+ # Check for cancellation during stream processing
403+ if cancel_signal and cancel_signal .is_set ():
404+ logger .debug ("cancellation detected during stream processing" )
405+ # Return cancelled stop reason with cancellation message
406+ # The incomplete message in state["message"] is discarded and never added to agent.messages
407+ yield ModelStopReason (
408+ stop_reason = "cancelled" ,
409+ message = {"role" : "assistant" , "content" : [{"text" : "Cancelled by user" }]},
410+ usage = usage ,
411+ metrics = metrics ,
412+ )
413+ return
414+
398415 # Track first byte time when we get first content
399416 if first_byte_time is None and ("contentBlockDelta" in chunk or "contentBlockStart" in chunk ):
400417 first_byte_time = time .time ()
@@ -431,6 +448,7 @@ async def stream_messages(
431448 tool_choice : Any | None = None ,
432449 system_prompt_content : list [SystemContentBlock ] | None = None ,
433450 invocation_state : dict [str , Any ] | None = None ,
451+ cancel_signal : threading .Event | None = None ,
434452 ** kwargs : Any ,
435453) -> AsyncGenerator [TypedEvent , None ]:
436454 """Streams messages to the model and processes the response.
@@ -444,6 +462,7 @@ async def stream_messages(
444462 system_prompt_content: The authoritative system prompt content blocks that always contains the
445463 system prompt data.
446464 invocation_state: Caller-provided state/context that was passed to the agent when it was invoked.
465+ cancel_signal: Optional threading.Event to check for cancellation during streaming.
447466 **kwargs: Additional keyword arguments for future extensibility.
448467
449468 Yields:
@@ -463,5 +482,5 @@ async def stream_messages(
463482 invocation_state = invocation_state ,
464483 )
465484
466- async for event in process_stream (chunks , start_time ):
485+ async for event in process_stream (chunks , start_time , cancel_signal ):
467486 yield event
0 commit comments