Skip to content

Commit 0a63829

Browse files
committed
Update implementation based on bar-raising
- Remove adapter from constructor - Implement BidirectionlIO interface - Add adapter the run() method
1 parent 2a2861b commit 0a63829

6 files changed

Lines changed: 193 additions & 166 deletions

File tree

src/strands/experimental/bidirectional_streaming/adapters/__init__.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

src/strands/experimental/bidirectional_streaming/agent/agent.py

Lines changed: 62 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@
2626
from ....tools.watcher import ToolWatcher
2727
from ....types.content import Message, Messages
2828
from ....types.tools import ToolResult, ToolUse, AgentTool
29-
from ..adapters.audio_adapter import AudioAdapter
29+
3030
from ..event_loop.bidirectional_event_loop import BidirectionalAgentLoop
3131
from ..models.bidirectional_model import BidirectionalModel
3232
from ..models.novasonic import NovaSonicModel
3333
from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent
34+
from ..types import BidirectionalIO
3435
from ....experimental.tools import ToolProvider
3536

3637
logger = logging.getLogger(__name__)
@@ -60,7 +61,6 @@ def __init__(
6061
name: Optional[str] = None,
6162
tool_executor: Optional[ToolExecutor] = None,
6263
description: Optional[str] = None,
63-
adapters: Optional[list[Any]] = None,
6464
**kwargs: Any,
6565
):
6666
"""Initialize bidirectional agent.
@@ -76,8 +76,6 @@ def __init__(
7676
name: Name of the Agent.
7777
tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.).
7878
description: Description of what the Agent does.
79-
adapters: Optional list of adapter instances (e.g., AudioAdapter) for hardware abstraction.
80-
If None, automatically creates default AudioAdapter for basic audio functionality.
8179
**kwargs: Additional configuration for future extensibility.
8280
8381
Raises:
@@ -125,14 +123,7 @@ def __init__(
125123
# connection management
126124
self._agentloop: Optional["BidirectionalAgentLoop"] = None
127125
self._output_queue = asyncio.Queue()
128-
129-
# Initialize adapters - auto-create AudioAdapter as default
130-
if adapters is None:
131-
# Create default AudioAdapter for basic audio functionality
132-
default_audio_adapter = AudioAdapter(audio_config={"input_sample_rate": 16000})
133-
self.adapters = [default_audio_adapter]
134-
else:
135-
self.adapters = adapters
126+
self._current_adapters = [] # Track adapters for cleanup
136127

137128
@property
138129
def tool(self) -> ToolCaller:
@@ -261,11 +252,11 @@ async def start(self) -> None:
261252
logger.debug("Conversation start - initializing connection")
262253

263254
# Create model session and event loop directly
264-
model_session = await self.model.create_bidirectional_connection(
255+
model_session = await self.model.connect(
265256
system_prompt=self.system_prompt, tools=self.tool_registry.get_all_tool_specs(), messages=self.messages
266257
)
267258

268-
self._agentloop = BidirectionalAgentLoop(model_session=model_session, agent=self)
259+
self._agentloop = BidirectionalAgentLoop(model=self.model, agent=self)
269260
await self._agentloop.start()
270261

271262
logger.debug("Conversation ready")
@@ -294,13 +285,13 @@ async def send(self, input_data: BidirectionalInput) -> None:
294285
logger.debug("Text sent: %d characters", len(input_data))
295286
# Create TextInputEvent for send()
296287
text_event = {"text": input_data, "role": "user"}
297-
await self._agentloop.model_session.send(text_event)
288+
await self._agentloop.model.send(text_event)
298289
elif isinstance(input_data, dict) and "audioData" in input_data:
299290
# Handle audio input
300-
await self._agentloop.model_session.send(input_data)
291+
await self._agentloop.model.send(input_data)
301292
elif isinstance(input_data, dict) and "imageData" in input_data:
302293
# Handle image input (ImageInputEvent)
303-
await self._agentloop.model_session.send(input_data)
294+
await self._agentloop.model.send(input_data)
304295
else:
305296
raise ValueError(
306297
"Input must be either a string (text), AudioInputEvent "
@@ -363,17 +354,20 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
363354
"""
364355
try:
365356
logger.debug("Exiting async context manager - cleaning up adapters and connection")
366-
367-
# Cleanup adapters first
368-
for adapter in self.adapters:
369-
if hasattr(adapter, "_cleanup_audio"):
357+
358+
# Cleanup adapters if any are currently active
359+
for adapter in self._current_adapters:
360+
if hasattr(adapter, "cleanup"):
370361
try:
371-
adapter._cleanup_audio()
362+
adapter.cleanup()
372363
logger.debug(f"Cleaned up adapter: {type(adapter).__name__}")
373364
except Exception as adapter_error:
374365
logger.warning(f"Error cleaning up adapter: {adapter_error}")
375-
376-
# Then cleanup agent connection
366+
367+
# Clear current adapters
368+
self._current_adapters = []
369+
370+
# Cleanup agent connection
377371
await self.end()
378372

379373
except Exception as cleanup_error:
@@ -396,72 +390,72 @@ def active(self) -> bool:
396390
"""
397391
return self._agentloop is not None and self._agentloop.active
398392

399-
async def connect(self) -> None:
400-
"""Connect the agent using configured adapters for bidirectional communication.
401-
402-
Automatically uses configured adapters to establish bidirectional communication
403-
with the model. If no adapters are provided in constructor, uses default AudioAdapter.
393+
async def run(self, io_channels: list[BidirectionalIO | tuple[Callable, Callable]]) -> None:
394+
"""Run the agent using provided IO channels or transport tuples for bidirectional communication.
404395
396+
Args:
397+
io_channels: List containing either BidirectionalIO instances or (sender, receiver) tuples.
398+
- BidirectionalIO: IO channel instance with input_channel(), output_channel(), and cleanup() methods
399+
- tuple: (sender_callable, receiver_callable) for custom transport
400+
405401
Example:
406402
```python
407-
# Simple - uses default AudioAdapter
403+
# With IO channel
404+
audio_io = AudioIO(audio_config={"input_sample_rate": 16000})
408405
agent = BidirectionalAgent(model=model, tools=[calculator])
409-
await agent.connect()
406+
await agent.run(io_channels=[audio_io])
410407
411-
# Custom adapter
412-
adapter = AudioAdapter(audio_config={"input_sample_rate": 24000})
413-
agent = BidirectionalAgent(model=model, tools=[calculator], adapters=[adapter])
414-
await agent.connect()
408+
# With tuple (backward compatibility)
409+
await agent.run(io_channels=[(sender_function, receiver_function)])
415410
```
416411
417412
Raises:
413+
ValueError: If io_channels list is empty or contains invalid items.
418414
Exception: Any exception from the transport layer.
419415
"""
420-
# Use first adapter (always available due to default initialization)
421-
adapter = self.adapters[0]
422-
sender = adapter.create_output()
423-
receiver = adapter.create_input()
416+
if not io_channels:
417+
raise ValueError("io_channels parameter cannot be empty. Provide either an IO channel or (sender, receiver) tuple.")
418+
419+
transport = io_channels[0]
420+
421+
# Set IO channel tracking for cleanup
422+
if hasattr(transport, 'input_channel') and hasattr(transport, 'output_channel'):
423+
self._current_adapters = [transport] # IO channel needs cleanup
424+
elif isinstance(transport, tuple) and len(transport) == 2:
425+
self._current_adapters = [] # Tuple needs no cleanup
426+
else:
427+
raise ValueError("io_channels list must contain either BidirectionalIO instances or (sender, receiver) tuples.")
424428

429+
# Auto-manage session lifecycle
425430
if self.active:
426-
# Use existing connection
427-
await self._run(sender, receiver)
431+
await self._run_with_transport(transport)
428432
else:
429-
# Use async context manager for automatic lifecycle management
430433
async with self:
431-
await self._run(sender, receiver)
434+
await self._run_with_transport(transport)
432435

433-
async def _run(
436+
async def _run_with_transport(
434437
self,
435-
sender: Callable[[Any], Any],
436-
receiver: Callable[[], Any],
438+
transport: BidirectionalIO | tuple[Callable, Callable],
437439
) -> None:
438-
"""Internal method to run send/receive loops with an active connection.
439-
440-
Args:
441-
sender: Async callable that sends events to the client.
442-
receiver: Async callable that receives events from the client.
443-
"""
440+
"""Internal method to run send/receive loops with an active connection."""
444441

445442
async def receive_from_agent():
446-
"""Receive events from agent and send to client."""
447-
try:
448-
async for event in self.receive():
449-
await sender(event)
450-
except Exception as e:
451-
logger.debug(f"Receive from agent stopped: {e}")
452-
raise
443+
"""Receive events from agent and send to transport."""
444+
async for event in self.receive():
445+
if hasattr(transport, 'output_channel'):
446+
await transport.output_channel(event)
447+
else:
448+
await transport[0](event)
453449

454450
async def send_to_agent():
455-
"""Receive events from client and send to agent."""
456-
try:
457-
while self.active:
458-
event = await receiver()
459-
await self.send(event)
460-
except Exception as e:
461-
logger.debug(f"Send to agent stopped: {e}")
462-
raise
451+
"""Receive events from transport and send to agent."""
452+
while self.active:
453+
if hasattr(transport, 'input_channel'):
454+
event = await transport.input_channel()
455+
else:
456+
event = await transport[1]()
457+
await self.send(event)
463458

464-
# Run both loops concurrently
465459
await asyncio.gather(receive_from_agent(), send_to_agent(), return_exceptions=True)
466460

467461
def _validate_active_connection(self) -> None:

src/strands/experimental/bidirectional_streaming/tests/test_bidi.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent))
88

99
from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent
10-
from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel
10+
from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicModel
11+
from strands.experimental.bidirectional_streaming.types.audio_io import AudioIO
1112
from strands_tools import calculator
1213

1314

@@ -16,12 +17,13 @@ async def main():
1617

1718

1819
# Nova Sonic model
19-
model = NovaSonicBidirectionalModel()
20+
adapter = AudioIO()
21+
model = NovaSonicModel(region="us-east-1")
2022

2123
async with BidirectionalAgent(model=model, tools=[calculator]) as agent:
2224
print("New BidirectionalAgent Experience")
2325
print("Try asking: 'What is 25 times 8?' or 'Calculate the square root of 144'")
24-
await agent.connect()
26+
await agent.run(io_channels=[adapter])
2527

2628

2729
if __name__ == "__main__":

src/strands/experimental/bidirectional_streaming/types/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Type definitions for bidirectional streaming."""
22

3+
from .audio_io import AudioIO
4+
from .bidirectional_io import BidirectionalIO
35
from .bidirectional_streaming import (
46
DEFAULT_CHANNELS,
57
DEFAULT_SAMPLE_RATE,
@@ -20,6 +22,8 @@
2022
)
2123

2224
__all__ = [
25+
"AudioIO",
26+
"BidirectionalIO",
2327
"AudioInputEvent",
2428
"AudioOutputEvent",
2529
"BidirectionalConnectionEndEvent",

0 commit comments

Comments
 (0)