2525from askui .speaker .speaker import SpeakerResult , Speakers
2626
2727if TYPE_CHECKING :
28+ from askui .models .shared .conversation_callback import ConversationCallback
2829 from askui .utils .caching .cache_manager import CacheManager
2930
3031logger = logging .getLogger (__name__ )
@@ -58,6 +59,7 @@ class Conversation:
5859 reporter: Reporter for logging messages and actions
5960 cache_manager: Cache manager for recording/playback (optional)
6061 truncation_strategy_factory: Factory for creating truncation strategies
62+ callbacks: List of callbacks for conversation lifecycle hooks (optional)
6163 """
6264
6365 def __init__ (
@@ -69,6 +71,7 @@ def __init__(
6971 reporter : Reporter = NULL_REPORTER ,
7072 cache_manager : "CacheManager | None" = None ,
7173 truncation_strategy_factory : TruncationStrategyFactory | None = None ,
74+ callbacks : "list[ConversationCallback] | None" = None ,
7275 ) -> None :
7376 """Initialize conversation with speakers and model providers."""
7477 if not speakers :
@@ -92,18 +95,33 @@ def __init__(
9295 truncation_strategy_factory or SimpleTruncationStrategyFactory ()
9396 )
9497 self ._truncation_strategy : TruncationStrategy | None = None
98+ self ._callbacks : "list[ConversationCallback]" = callbacks or []
9599
96100 # State for current execution (set in start())
97101 self .settings : ActSettings = ActSettings ()
98102 self .tools : ToolCollection = ToolCollection ()
99103 self ._reporters : list [Reporter ] = []
104+ self ._step_index : int = 0
100105
101106 # Cache execution context (for communication between tools and CacheExecutor)
102107 self .cache_execution_context : dict [str , Any ] = {}
103108
104109 # Track if cache execution was used (to prevent recording during playback)
105110 self ._executed_from_cache : bool = False
106111
112+ def _call_callbacks (self , method_name : str , * args : Any , ** kwargs : Any ) -> None :
113+ """Call a method on all registered callbacks.
114+
115+ Args:
116+ method_name: Name of the callback method to call
117+ *args: Positional arguments to pass to the callback
118+ **kwargs: Keyword arguments to pass to the callback
119+ """
120+ for callback in self ._callbacks :
121+ method = getattr (callback , method_name , None )
122+ if method and callable (method ):
123+ method (self , * args , ** kwargs )
124+
107125 @tracer .start_as_current_span ("conversation" )
108126 def execute_conversation (
109127 self ,
@@ -119,7 +137,6 @@ def execute_conversation(
119137
120138 Args:
121139 messages: Initial message history
122- on_message: Optional callback for each message
123140 tools: Available tools
124141 settings: Agent settings
125142 reporters: Optional list of additional reporters for this conversation
@@ -128,7 +145,11 @@ def execute_conversation(
128145 logger .info (msg )
129146
130147 self ._setup_control_loop (messages , tools , settings , reporters )
148+
149+ self ._call_callbacks ("on_conversation_start" )
131150 self ._execute_control_loop ()
151+ self ._call_callbacks ("on_conversation_end" )
152+
132153 self ._conclude_control_loop ()
133154
134155 @tracer .start_as_current_span ("setup_control_loop" )
@@ -162,9 +183,12 @@ def _setup_control_loop(
162183
163184 @tracer .start_as_current_span ("control_loop" )
164185 def _execute_control_loop (self ) -> None :
186+ self ._call_callbacks ("on_control_loop_start" )
187+ self ._step_index = 0
165188 continue_execution = True
166189 while continue_execution :
167190 continue_execution = self ._execute_step ()
191+ self ._call_callbacks ("on_control_loop_end" )
168192
169193 @tracer .start_as_current_span ("finish_control_loop" )
170194 def _conclude_control_loop (self ) -> None :
@@ -189,6 +213,7 @@ def _execute_step(self) -> bool:
189213 Returns:
190214 True if loop should continue, False if done
191215 """
216+ self ._call_callbacks ("on_step_start" , self ._step_index )
192217
193218 # 1. Infer next speaker
194219 speaker = self .current_speaker
@@ -226,6 +251,9 @@ def _execute_step(self) -> bool:
226251 if result .usage :
227252 self ._accumulate_usage (result .usage )
228253
254+ self ._call_callbacks ("on_step_end" , self ._step_index , result )
255+ self ._step_index += 1
256+
229257 return continue_loop
230258
231259 @tracer .start_as_current_span ("execute_tool_call" )
@@ -255,8 +283,11 @@ def _execute_tools_if_present(self, message: MessageParam) -> MessageParam | Non
255283 return None
256284
257285 # Execute tools
286+ tool_names = [block .name for block in tool_use_blocks ]
258287 logger .debug ("Executing %d tool(s)" , len (tool_use_blocks ))
288+ self ._call_callbacks ("on_tool_execution_start" , tool_names )
259289 tool_results = self .tools .run (tool_use_blocks )
290+ self ._call_callbacks ("on_tool_execution_end" , tool_names )
260291
261292 if not tool_results :
262293 return None
0 commit comments