Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ bidirectional-streaming-nova = [
]
bidirectional-streaming-openai = [
"pyaudio>=0.2.13",
"websockets>=12.0,<14.0",
"websockets>=14.0,<16.0",
]
bidirectional-streaming = [
"pyaudio>=0.2.13",
"rx>=3.2.0",
"smithy-aws-core>=0.0.1",
"pytz",
"aws_sdk_bedrock_runtime",
"websockets>=12.0,<14.0",
"websockets>=14.0,<16.0",
]
otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"]
docs = [
Expand Down
54 changes: 40 additions & 14 deletions src/strands/experimental/bidirectional_streaming/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,27 @@
# Event types - For type hints and event handling
from .types.bidirectional_streaming import (
AudioInputEvent,
AudioOutputEvent,
BidirectionalStreamEvent,
AudioStreamEvent,
ConnectionCloseEvent,
ConnectionStartEvent,
ErrorEvent,
ImageInputEvent,
InterruptionDetectedEvent,
InputEvent,
InterruptionEvent,
ModalityUsage,
UsageEvent,
OutputEvent,
ResponseCompleteEvent,
ResponseStartEvent,
TextInputEvent,
TextOutputEvent,
UsageMetricsEvent,
VoiceActivityEvent,
TranscriptStreamEvent,
)

# Re-export standard agent events for tool handling
from ...types._events import (
ToolResultEvent,
ToolStreamEvent,
ToolUseStreamEvent,
)

__all__ = [
Expand All @@ -33,16 +46,29 @@
"NovaSonicModel",
"OpenAIRealtimeModel",

# Event types
# Input Event types
"TextInputEvent",
"AudioInputEvent",
"AudioOutputEvent",
"ImageInputEvent",
"TextInputEvent",
"TextOutputEvent",
"InterruptionDetectedEvent",
"BidirectionalStreamEvent",
"VoiceActivityEvent",
"UsageMetricsEvent",
"InputEvent",

# Output Event types
"ConnectionStartEvent",
"ConnectionCloseEvent",
"ResponseStartEvent",
"ResponseCompleteEvent",
"AudioStreamEvent",
"TranscriptStreamEvent",
"InterruptionEvent",
"UsageEvent",
"ModalityUsage",
"ErrorEvent",
"OutputEvent",

# Tool Event types (reused from standard agent)
"ToolUseStreamEvent",
"ToolResultEvent",
"ToolStreamEvent",

# Model interface
"BidirectionalModel",
Expand Down
89 changes: 69 additions & 20 deletions src/strands/experimental/bidirectional_streaming/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@
from ....types.traces import AttributeValue
from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection
from ..models.bidirectional_model import BidirectionalModel
from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent
from ..types.bidirectional_streaming import (
AudioInputEvent,
ImageInputEvent,
InputEvent,
OutputEvent,
TextInputEvent,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -165,6 +171,8 @@ def __init__(
hooks: Optional[list[HookProvider]] = None,
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
description: Optional[str] = None,
enable_reconnection: bool = True,
max_reconnection_attempts: int = 3,
):
"""Initialize bidirectional agent with required model and optional configuration.

Expand All @@ -181,11 +189,17 @@ def __init__(
hooks: Hooks to be added to the agent hook registry.
trace_attributes: Custom trace attributes to apply to the agent's trace span.
description: Description of what the Agent does.
enable_reconnection: Whether to automatically reconnect on connection failures (default: True).
max_reconnection_attempts: Maximum number of reconnection attempts (default: 3).
"""
self.model = model
self.system_prompt = system_prompt
self.messages = messages or []

# Reconnection configuration
self.enable_reconnection = enable_reconnection
self.max_reconnection_attempts = max_reconnection_attempts

# Agent identification
self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
self.name = name or _DEFAULT_AGENT_NAME
Expand Down Expand Up @@ -360,53 +374,88 @@ async def start(self) -> None:
logger.debug("Conversation start - initializing session")
self._session = await start_bidirectional_connection(self)

async def send(self, input_data: str | AudioInputEvent | ImageInputEvent) -> None:
"""Send input to the model (text, audio, or image).
async def send(self, input_data: str | AudioInputEvent | ImageInputEvent | dict) -> None:
"""Send input to the model (text, audio, image, or event dict).

Unified method for sending text, audio, and image input to the model during
an active conversation session.
an active conversation session. Accepts TypedEvent instances or plain dicts
(e.g., from WebSocket clients) which are automatically reconstructed.

Args:
input_data: String for text, AudioInputEvent for audio, or ImageInputEvent for images.
input_data: Can be:
- str: Text message from user
- AudioInputEvent: Audio data with format/sample rate
- ImageInputEvent: Image data with MIME type
- dict: Event dictionary (will be reconstructed to TypedEvent)

Raises:
ValueError: If no active session or invalid input type.

Example:
await agent.send("Hello")
await agent.send(AudioInputEvent(audio="base64...", format="pcm", ...))
await agent.send({"type": "bidirectional_text_input", "text": "Hello", "role": "user"})
"""
self._validate_active_session()

# Handle string input
if isinstance(input_data, str):
# Add user text message to history
self.messages.append({"role": "user", "content": input_data})

logger.debug("Text sent: %d characters", len(input_data))
# Create TextInputEvent for send()
text_event = {"text": input_data, "role": "user"}
text_event = TextInputEvent(text=input_data, role="user")
await self._session.model.send(text_event)
elif isinstance(input_data, dict) and "audioData" in input_data:
# Handle audio input - already in AudioInputEvent format
return

# Handle InputEvent instances (TextInputEvent, AudioInputEvent, ImageInputEvent)
# Check this before dict since TypedEvent inherits from dict
if isinstance(input_data, (TextInputEvent, AudioInputEvent, ImageInputEvent)):
await self._session.model.send(input_data)
elif isinstance(input_data, dict) and "imageData" in input_data:
# Handle image input - already in ImageInputEvent format
return

# Handle plain dict - reconstruct TypedEvent for WebSocket integration
if isinstance(input_data, dict) and "type" in input_data:
event_type = input_data["type"]
if event_type == "bidirectional_text_input":
input_data = TextInputEvent(text=input_data["text"], role=input_data["role"])
elif event_type == "bidirectional_audio_input":
input_data = AudioInputEvent(
audio=input_data["audio"],
format=input_data["format"],
sample_rate=input_data["sample_rate"],
channels=input_data["channels"]
)
elif event_type == "bidirectional_image_input":
input_data = ImageInputEvent(
image=input_data["image"],
mime_type=input_data["mime_type"]
)
else:
raise ValueError(f"Unknown event type: {event_type}")

# Send the reconstructed TypedEvent
await self._session.model.send(input_data)
else:
raise ValueError(
"Input must be either a string (text), AudioInputEvent "
"(dict with audioData, format, sampleRate, channels), or ImageInputEvent "
"(dict with imageData, mimeType, encoding)"
)
return

# If we get here, input type is invalid
raise ValueError(
f"Input must be a string, InputEvent (TextInputEvent/AudioInputEvent/ImageInputEvent), or event dict with 'type' field, got: {type(input_data)}"
)

async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]:
async def receive(self) -> AsyncIterable["OutputEvent"]:
"""Receive events from the model including audio, text, and tool calls.

Yields model output events processed by background tasks including audio output,
text responses, tool calls, and session updates.

Yields:
BidirectionalStreamEvent: Events from the model session.
OutputEvent: TypedEvent objects from the model session. Events are
JSON-serializable by default (use json.dumps(event) for transport).
"""
while self._session and self._session.active:
try:
event = await asyncio.wait_for(self._output_queue.get(), timeout=0.1)
# Return TypedEvent objects directly (JSON-serializable by default)
yield event
except asyncio.TimeoutError:
continue
Expand Down
Loading
Loading