forked from strands-agents/sdk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbidirectional_event_loop.py
More file actions
480 lines (372 loc) · 18 KB
/
bidirectional_event_loop.py
File metadata and controls
480 lines (372 loc) · 18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
"""Bidirectional session management for concurrent streaming conversations.
Manages bidirectional communication sessions with concurrent processing of model events,
tool execution, and audio processing. Provides coordination between background tasks
while maintaining a simple interface for agent interaction.
Features:
- Concurrent task management for model events and tool execution
- Interruption handling with audio buffer clearing
- Tool execution with cancellation support
- Session lifecycle management
"""
import asyncio
import logging
import traceback
import uuid
from ....tools._validator import validate_and_prepare_tools
from ....telemetry.metrics import Trace
from ....types._events import ToolResultEvent, ToolStreamEvent
from ....types.content import Message
from ....types.tools import ToolResult, ToolUse
from ..models.bidirectional_model import BidirectionalModel
logger = logging.getLogger(__name__)

# Session timing constants (seconds; consumed by asyncio.wait_for / asyncio.sleep).
# Longest the tool processor blocks on an empty tool queue before re-checking session state.
TOOL_QUEUE_TIMEOUT = 0.5
# Pause between two supervision passes of the main event loop cycle.
SUPERVISION_INTERVAL = 0.1
class BidirectionalConnection:
    """Mutable state shared by everything that runs during one bidirectional session.

    Bundles the model and agent being bridged, the queues the background
    processors communicate through, the registry of in-flight tool tasks, and
    the bookkeeping used when the user interrupts the model.
    """

    def __init__(self, model: BidirectionalModel, agent: "BidirectionalAgent") -> None:
        """Build fresh session state around ``model`` and ``agent``.

        Args:
            model: Bidirectional model instance.
            agent: BidirectionalAgent instance for tool registry access.
        """
        self.model = model
        self.agent = agent
        # Polled by the background loops; set to False when the session ends.
        self.active = True

        # Background processor tasks; populated after construction.
        self.background_tasks: list[asyncio.Task] = []
        # Tool-use payloads waiting to be executed.
        self.tool_queue: asyncio.Queue = asyncio.Queue()
        # Audio output events that are emptied on interruption.
        self.audio_output_queue: asyncio.Queue = asyncio.Queue()

        # In-flight tool tasks keyed by a generated id so they can be cancelled.
        self.pending_tool_tasks: dict[str, asyncio.Task] = {}

        # Interruption bookkeeping (model-agnostic): the lock serializes handling.
        self.interruption_lock = asyncio.Lock()
        self.interrupted = False

        # Running count of tool uses dispatched during this session.
        self.tool_count = 0
async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection:
    """Connect the agent's model and spin up the session's concurrent tasks.

    Opens the model connection with the agent's system prompt, every tool spec
    from its tool registry, and its current message history. Then wraps the
    model in a BidirectionalConnection and starts two background tasks (model
    event processing and tool execution), followed by the main supervision cycle.

    Args:
        agent: BidirectionalAgent instance.

    Returns:
        BidirectionalConnection: Active session with background tasks running.
    """
    logger.debug("Starting bidirectional session - initializing model connection")
    await agent.model.connect(
        system_prompt=agent.system_prompt,
        tools=agent.tool_registry.get_all_tool_specs(),
        messages=agent.messages,
    )

    session = BidirectionalConnection(model=agent.model, agent=agent)

    # Processors start right after the session exists: Nova Sonic needs its
    # responses consumed while it is still initializing.
    logger.debug("Starting background processors for concurrent processing")
    processors = (_process_model_events, _process_tool_execution)
    session.background_tasks = [asyncio.create_task(processor(session)) for processor in processors]

    # Supervisor that watches the processors for completion or failure.
    session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session))

    logger.debug("Session ready with %d background tasks", len(session.background_tasks))
    return session
async def stop_bidirectional_connection(session: BidirectionalConnection) -> None:
    """End a session: cancel its tasks, wait for them, then close the model.

    Calling this on a session that is already inactive does nothing. The model
    connection is closed in a ``finally`` block, so it is released even if waiting
    for the tasks is interrupted (for example, because the caller is cancelled).

    Args:
        session: BidirectionalConnection to clean up.
    """
    if not session.active:
        return

    logger.debug("Session cleanup starting")
    session.active = False

    # start_bidirectional_connection sets this; tolerate sessions where it is absent or None.
    main_cycle_task = getattr(session, "main_cycle_task", None)

    try:
        # Cancel pending tool tasks
        for task in session.pending_tool_tasks.values():
            if not task.done():
                task.cancel()

        # Cancel background tasks
        for task in session.background_tasks:
            if not task.done():
                task.cancel()

        # Cancel main cycle task
        if main_cycle_task is not None and not main_cycle_task.done():
            main_cycle_task.cancel()

        # Snapshot before awaiting: done-callbacks remove entries from
        # pending_tool_tasks as tool tasks finish.
        all_tasks = [*session.background_tasks, *session.pending_tool_tasks.values()]
        if main_cycle_task is not None:
            all_tasks.append(main_cycle_task)

        if all_tasks:
            await asyncio.gather(*all_tasks, return_exceptions=True)
    finally:
        # Always release the model connection, even if the wait above was cut short.
        await session.model.close()

    logger.debug("Connection closed")
async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None:
    """Supervise the session's background processors until the session ends.

    While ``session.active`` is true, this inspects ``session.background_tasks``
    every ``SUPERVISION_INTERVAL`` seconds.
    - If a processor finished with an exception, the session is marked inactive
      and that exception is re-raised from this coroutine.
    - Once every processor has finished without an exception, the session is
      marked inactive and the coroutine returns.
    - Cancelling this coroutine ends supervision quietly.

    Args:
        session: BidirectionalConnection to coordinate.

    Raises:
        Exception: The first exception raised by a background processor, or any
            unexpected error during supervision.
    """
    while session.active:
        try:
            # Check failures BEFORE completion: if every processor has already
            # finished, a failure among them must still be surfaced instead of
            # being reported as a normal session end.
            for index, task in enumerate(session.background_tasks):
                if not task.done() or task.cancelled():
                    continue
                exception = task.exception()
                if exception:
                    logger.error("Session error in processor %d: %s", index, str(exception))
                    session.active = False
                    raise exception

            # Every processor finished cleanly (or was cancelled)
            if all(task.done() for task in session.background_tasks):
                logger.debug("Session end - all processors completed")
                session.active = False
                break

            # Brief pause before next supervision check
            await asyncio.sleep(SUPERVISION_INTERVAL)
        except asyncio.CancelledError:
            # Cancelled by stop_bidirectional_connection (or another owner); end quietly.
            break
        except Exception as e:
            logger.error("Event loop error: %s", str(e))
            session.active = False
            raise
async def _handle_interruption(session: BidirectionalConnection) -> None:
"""Handle interruption detection with task cancellation and audio buffer clearing.
Cancels pending tool tasks and clears audio output queues to ensure responsive
interruption handling during conversations. Protected by async lock to prevent
concurrent execution and race conditions.
Args:
session: BidirectionalConnection to handle interruption for.
"""
async with session.interruption_lock:
# If already interrupted, skip duplicate processing
if session.interrupted:
logger.debug("Interruption already in progress")
return
logger.debug("Interruption detected")
session.interrupted = True
# Cancel all pending tool execution tasks
cancelled_tools = 0
for _task_id, task in list(session.pending_tool_tasks.items()):
if not task.done():
task.cancel()
cancelled_tools += 1
logger.debug("Tool task cancelled: %s", _task_id)
if cancelled_tools > 0:
logger.debug("Tool tasks cancelled: %d", cancelled_tools)
# Clear all queued audio output events
cleared_count = 0
while True:
try:
session.audio_output_queue.get_nowait()
cleared_count += 1
except asyncio.QueueEmpty:
break
# Also clear the agent's audio output queue
audio_cleared = 0
# Create a temporary list to hold non-audio events
temp_events = []
try:
while True:
event = session.agent._output_queue.get_nowait()
if event.get("audioOutput"):
audio_cleared += 1
else:
# Keep non-audio events
temp_events.append(event)
except asyncio.QueueEmpty:
pass
# Put back non-audio events
for event in temp_events:
session.agent._output_queue.put_nowait(event)
if audio_cleared > 0:
logger.debug("Agent audio queue cleared: %d events", audio_cleared)
if cleared_count > 0:
logger.debug("Session audio queue cleared: %d events", cleared_count)
# Reset interruption flag after clearing (automatic recovery)
session.interrupted = False
logger.debug("Interruption handled - tools cancelled: %d, audio cleared: %d", cancelled_tools, cleared_count)
async def _process_model_events(session: BidirectionalConnection) -> None:
"""Process model events and convert them to Strands format.
Background task that handles all model responses, converts provider-specific
events to standardized formats, and manages interruption detection.
Args:
session: BidirectionalConnection containing model.
"""
logger.debug("Model events processor started")
try:
async for provider_event in session.model.receive():
if not session.active:
break
# Basic validation - skip invalid events
if not isinstance(provider_event, dict):
continue
strands_event = provider_event
# Handle interruption detection (provider converts raw patterns to interruptionDetected)
if strands_event.get("interruptionDetected"):
logger.debug("Interruption forwarded")
await _handle_interruption(session)
# Forward interruption event to agent for application-level handling
await session.agent._output_queue.put(strands_event)
continue
# Queue tool requests for concurrent execution
if strands_event.get("toolUse"):
tool_name = strands_event["toolUse"].get("name")
logger.debug("Tool usage detected: %s", tool_name)
await session.tool_queue.put(strands_event["toolUse"])
continue
# Send output events to Agent for receive() method
if strands_event.get("audioOutput") or strands_event.get("textOutput"):
await session.agent._output_queue.put(strands_event)
# Update Agent conversation history using existing patterns
if strands_event.get("messageStop"):
logger.debug("Message added to history")
session.agent.messages.append(strands_event["messageStop"]["message"])
# Handle user audio transcripts - add to message history
if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user":
user_transcript = strands_event["textOutput"]["text"]
if user_transcript.strip(): # Only add non-empty transcripts
user_message = {"role": "user", "content": user_transcript}
session.agent.messages.append(user_message)
logger.debug("User transcript added to history")
except Exception as e:
logger.error("Model events error: %s", str(e))
traceback.print_exc()
finally:
logger.debug("Model events processor stopped")
async def _process_tool_execution(session: BidirectionalConnection) -> None:
    """Pull tool uses off the session queue and run each one as its own task.

    Runs in the background for as long as ``session.active`` is true. Each tool
    use is wrapped in an asyncio task, so tool work never blocks model event
    processing. The task is registered in ``session.pending_tool_tasks`` so that
    interruption handling and shutdown can cancel it. Finished tasks leave the
    registry in two ways:
    - a done-callback removes the task and logs how it ended;
    - whenever the queue wait times out, a sweep removes any finished tasks
      still registered.

    Args:
        session: BidirectionalConnection containing tool queue.
    """

    def _on_tool_task_done(finished: asyncio.Task, task_id: str) -> None:
        # Done-callback: forget the task and report its outcome; never raise.
        try:
            session.pending_tool_tasks.pop(task_id, None)
            if finished.cancelled():
                logger.debug("Tool task cancelled: %s", task_id)
                return
            error = finished.exception()
            if error:
                logger.error("Tool task error: %s - %s", task_id, str(error))
            else:
                logger.debug("Tool task completed: %s", task_id)
        except Exception as callback_error:
            logger.error("Tool task cleanup failed: %s - %s", task_id, str(callback_error))

    logger.debug("Tool execution processor started")

    while session.active:
        try:
            tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT)
            tool_name = tool_use.get("name")
            tool_id = tool_use.get("toolUseId")

            session.tool_count += 1
            print(f"\nTool #{session.tool_count}: {tool_name}")
            logger.debug("Tool execution started: %s (id: %s)", tool_name, tool_id)

            task_id = str(uuid.uuid4())
            task = asyncio.create_task(_execute_tool_with_strands(session, tool_use))
            session.pending_tool_tasks[task_id] = task
            # Bind the id now; the callback receives only the finished task.
            task.add_done_callback(lambda finished, tid=task_id: _on_tool_task_done(finished, tid))
        except asyncio.TimeoutError:
            # Idle queue: stop if the session ended, otherwise sweep finished tasks.
            if not session.active:
                break
            finished_ids = [tid for tid, pending in session.pending_tool_tasks.items() if pending.done()]
            for tid in finished_ids:
                session.pending_tool_tasks.pop(tid, None)
            if finished_ids:
                logger.debug("Periodic task cleanup: %d tasks", len(finished_ids))
        except Exception as e:
            logger.error("Tool execution error: %s", str(e))
            if not session.active:
                break

    logger.debug("Tool execution processor stopped")
async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None:
    """Run one tool use through the Strands tool execution system and report results.

    The steps are:
    1. Validate the tool use with ``validate_and_prepare_tools``.
    2. Execute it via ``session.agent.tool_executor._execute``.
    3. Forward every ``ToolResultEvent`` to the model with ``session.model.send``.
    4. Append the collected results to ``session.agent.messages`` as a user
       message and fire a ``MessageAddedEvent`` hook.

    If validation rejects the tool use, any results that validation recorded in
    ``tool_results`` are sent to the model before returning. Otherwise the model
    would never receive a result for this tool use.

    Args:
        session: BidirectionalConnection for context.
        tool_use: Tool use payload to execute (``name`` and ``toolUseId`` are read).

    Raises:
        asyncio.CancelledError: If the task is cancelled (e.g. by an interruption).
        Exception: Any other failure, re-raised after an error result is sent to the model.
    """
    tool_name = tool_use.get("name")
    tool_id = tool_use.get("toolUseId")
    logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id)

    try:
        # Wrap the single tool use in an assistant message so the standard validator can process it
        tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]}

        # Use Strands validation system
        tool_uses: list[ToolUse] = []
        tool_results: list[ToolResult] = []
        invalid_tool_use_ids: list[str] = []
        validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids)

        # Filter valid tools
        valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids]
        if not valid_tool_uses:
            logger.warning("No valid tools after validation: %s", tool_name)
            # NOTE(review): validate_and_prepare_tools is expected to record an error
            # ToolResult in tool_results for each rejected tool use (confirm against the
            # validator). Forward them so the model gets an answer for this toolUseId.
            for validation_result in tool_results:
                await session.model.send(validation_result)
                logger.debug("Validation error result sent: %s", validation_result.get("toolUseId"))
            return

        # Create invocation state for tool execution
        invocation_state = {
            "agent": session.agent,
            "model": session.agent.model,
            "messages": session.agent.messages,
            "system_prompt": session.agent.system_prompt,
        }

        # Trace for this execution; no tracing span is used in the bidirectional loop
        cycle_trace = Trace("Bidirectional Tool Execution")
        cycle_span = None

        tool_events = session.agent.tool_executor._execute(
            session.agent,
            valid_tool_uses,
            tool_results,
            cycle_trace,
            cycle_span,
            invocation_state,
        )

        # Process tool events and send results to provider
        async for tool_event in tool_events:
            if isinstance(tool_event, ToolResultEvent):
                tool_result = tool_event.tool_result
                await session.model.send(tool_result)
                logger.debug("Tool result sent: %s", tool_result.get("toolUseId"))
            elif isinstance(tool_event, ToolStreamEvent):
                # Streaming updates are only logged for now
                logger.debug("Tool stream event: %s", tool_event)

        # Add tool result message to conversation history
        if tool_results:
            # Function-level import, kept as in the original
            from ....hooks import MessageAddedEvent

            tool_result_message: Message = {
                "role": "user",
                "content": [{"toolResult": result} for result in tool_results],
            }
            session.agent.messages.append(tool_result_message)
            session.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=session.agent, message=tool_result_message))
            logger.debug("Tool result message added to history: %s", tool_name)

        logger.debug("Tool execution completed: %s", tool_name)

    except asyncio.CancelledError:
        logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id)
        raise
    except Exception as e:
        logger.error("Tool execution error: %s - %s", tool_name, str(e))

        # Send error result so the model is not left waiting on this tool use
        error_result: ToolResult = {
            "toolUseId": tool_id,
            "status": "error",
            "content": [{"text": f"Error: {str(e)}"}],
        }
        try:
            await session.model.send(error_result)
            logger.debug("Error result sent: %s", tool_id)
        except Exception as send_error:
            logger.error("Failed to send error result: %s - %s", tool_id, str(send_error))

        raise  # Propagate exception since this is experimental code