Skip to content

Commit 30e6b1e

Browse files
authored
Merge pull request #20 from mkmeral/bidi-event-types
Event Types
2 parents 5a22ad9 + 330831d commit 30e6b1e

27 files changed

Lines changed: 1919 additions & 1049 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,15 @@ bidirectional-streaming-nova = [
6363
]
6464
bidirectional-streaming-openai = [
6565
"pyaudio>=0.2.13",
66-
"websockets>=12.0,<14.0",
66+
"websockets>=14.0,<16.0",
6767
]
6868
bidirectional-streaming = [
6969
"pyaudio>=0.2.13",
7070
"rx>=3.2.0",
7171
"smithy-aws-core>=0.0.1",
7272
"pytz",
7373
"aws_sdk_bedrock_runtime",
74-
"websockets>=12.0,<14.0",
74+
"websockets>=14.0,<16.0",
7575
]
7676
otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"]
7777
docs = [

src/strands/experimental/bidirectional_streaming/__init__.py

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,29 @@
1515
from .models.openai import BidiOpenAIRealtimeModel
1616

1717
# Event types - For type hints and event handling
18-
from .types.bidirectional_streaming import (
19-
AudioInputEvent,
20-
AudioOutputEvent,
21-
BidirectionalStreamEvent,
22-
ImageInputEvent,
23-
InterruptionDetectedEvent,
24-
TextInputEvent,
25-
TextOutputEvent,
26-
UsageMetricsEvent,
27-
VoiceActivityEvent,
18+
from .types.events import (
19+
BidiAudioInputEvent,
20+
BidiAudioStreamEvent,
21+
BidiConnectionCloseEvent,
22+
BidiConnectionStartEvent,
23+
BidiErrorEvent,
24+
BidiImageInputEvent,
25+
BidiInputEvent,
26+
BidiInterruptionEvent,
27+
ModalityUsage,
28+
BidiUsageEvent,
29+
BidiOutputEvent,
30+
BidiResponseCompleteEvent,
31+
BidiResponseStartEvent,
32+
BidiTextInputEvent,
33+
BidiTranscriptStreamEvent,
34+
)
35+
36+
# Re-export standard agent events for tool handling
37+
from ...types._events import (
38+
ToolResultEvent,
39+
ToolStreamEvent,
40+
ToolUseStreamEvent,
2841
)
2942

3043
__all__ = [
@@ -37,16 +50,30 @@
3750
"BidiNovaSonicModel",
3851
"BidiOpenAIRealtimeModel",
3952

40-
# Event types
41-
"AudioInputEvent",
42-
"AudioOutputEvent",
43-
"ImageInputEvent",
44-
"TextInputEvent",
45-
"TextOutputEvent",
46-
"InterruptionDetectedEvent",
47-
"BidirectionalStreamEvent",
48-
"VoiceActivityEvent",
49-
"UsageMetricsEvent",
53+
# Input Event types
54+
"BidiTextInputEvent",
55+
"BidiAudioInputEvent",
56+
"BidiImageInputEvent",
57+
"BidiInputEvent",
58+
59+
# Output Event types
60+
"BidiConnectionStartEvent",
61+
"BidiConnectionCloseEvent",
62+
"BidiResponseStartEvent",
63+
"BidiResponseCompleteEvent",
64+
"BidiAudioStreamEvent",
65+
"BidiTranscriptStreamEvent",
66+
"BidiInterruptionEvent",
67+
"BidiUsageEvent",
68+
"ModalityUsage",
69+
"BidiErrorEvent",
70+
"BidiOutputEvent",
71+
72+
# Tool Event types (reused from standard agent)
73+
"ToolUseStreamEvent",
74+
"ToolResultEvent",
75+
"ToolStreamEvent",
76+
5077
# Model interface
5178
"BidiModel",
5279
]

src/strands/experimental/bidirectional_streaming/agent/agent.py

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,22 @@
2727
from ....types.content import Message, Messages
2828
from ....types.tools import ToolResult, ToolUse, AgentTool
2929

30-
from ..event_loop.bidirectional_event_loop import BidirectionalConnection
30+
from ..event_loop.bidirectional_event_loop import (
31+
BidirectionalConnection,
32+
start_bidirectional_connection,
33+
stop_bidirectional_connection,
34+
)
3135
from ..models.bidirectional_model import BidiModel
3236
from ..models.novasonic import BidiNovaSonicModel
33-
from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent
37+
from ..types.agent import BidiAgentInput
38+
from ..types.events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent, BidiInputEvent, BidiOutputEvent
3439
from ..types import BidiIO
3540
from ....experimental.tools import ToolProvider
3641

3742
logger = logging.getLogger(__name__)
3843

3944
_DEFAULT_AGENT_NAME = "Strands Agents"
4045
_DEFAULT_AGENT_ID = "default"
41-
# Type alias for cleaner send() method signature
42-
BidirectionalInput = str | AudioInputEvent | ImageInputEvent
4346

4447

4548
class BidiAgent:
@@ -250,47 +253,81 @@ async def start(self) -> None:
250253
raise ValueError("Conversation already active. Call end() first.")
251254

252255
logger.debug("Conversation start - initializing connection")
256+
self._agent_loop = await start_bidirectional_connection(self)
253257

254-
# Create model session and event loop directly
255-
await self.model.start(
256-
system_prompt=self.system_prompt, tools=self.tool_registry.get_all_tool_specs(), messages=self.messages
257-
)
258-
259-
self._agent_loop = BidirectionalConnection(model=self.model, agent=self)
260-
await self._agent_loop.start()
261-
262-
logger.debug("Conversation ready")
263-
264-
async def send(self, input_data: BidirectionalInput) -> None:
265-
"""Send input to the model (text or audio).
266-
267-
Unified method for sending both text and audio input to the model during
268-
an active conversation connection. User input is automatically added to
269-
conversation history for complete message tracking.
270-
258+
async def send(self, input_data: BidiAgentInput) -> None:
259+
"""Send input to the model (text, audio, image, or event dict).
260+
261+
Unified method for sending text, audio, and image input to the model during
262+
an active conversation session. Accepts TypedEvent instances or plain dicts
263+
(e.g., from WebSocket clients) which are automatically reconstructed.
264+
271265
Args:
272-
input_data: String for text, AudioInputEvent for audio, or ImageInputEvent for images.
273-
266+
input_data: Can be:
267+
- str: Text message from user
268+
- BidiAudioInputEvent: Audio data with format/sample rate
269+
- BidiImageInputEvent: Image data with MIME type
270+
- dict: Event dictionary (will be reconstructed to TypedEvent)
271+
274272
Raises:
275-
ValueError: If no active connection or invalid input type.
273+
ValueError: If no active session or invalid input type.
274+
275+
Example:
276+
await agent.send("Hello")
277+
await agent.send(BidiAudioInputEvent(audio="base64...", format="pcm", ...))
278+
await agent.send({"type": "bidi_text_input", "text": "Hello", "role": "user"})
276279
"""
277280
self._validate_active_connection()
278281

282+
# Handle string input
279283
if isinstance(input_data, str):
280284
# Add user text message to history
281285
user_message: Message = {"role": "user", "content": [{"text": input_data}]}
282286

283287
self.messages.append(user_message)
284288

285289
logger.debug("Text sent: %d characters", len(input_data))
286-
# Create TextInputEvent for send()
287-
text_event = {"text": input_data, "role": "user"}
290+
# Create BidiTextInputEvent for send()
291+
text_event = BidiTextInputEvent(text=input_data, role="user")
288292
await self._agent_loop.model.send(text_event)
289-
else:
290-
# For audio, image, or any other input - let model handle it
293+
return
294+
295+
# Handle BidiInputEvent instances
296+
# Check this before dict since TypedEvent inherits from dict
297+
if isinstance(input_data, BidiInputEvent):
291298
await self._agent_loop.model.send(input_data)
299+
return
300+
301+
# Handle plain dict - reconstruct TypedEvent for WebSocket integration
302+
if isinstance(input_data, dict) and "type" in input_data:
303+
event_type = input_data["type"]
304+
if event_type == "bidi_text_input":
305+
input_event = BidiTextInputEvent(text=input_data["text"], role=input_data["role"])
306+
elif event_type == "bidi_audio_input":
307+
input_event = BidiAudioInputEvent(
308+
audio=input_data["audio"],
309+
format=input_data["format"],
310+
sample_rate=input_data["sample_rate"],
311+
channels=input_data["channels"]
312+
)
313+
elif event_type == "bidi_image_input":
314+
input_event = BidiImageInputEvent(
315+
image=input_data["image"],
316+
mime_type=input_data["mime_type"]
317+
)
318+
else:
319+
raise ValueError(f"Unknown event type: {event_type}")
320+
321+
# Send the reconstructed TypedEvent
322+
await self._agent_loop.model.send(input_event)
323+
return
324+
325+
# If we get here, input type is invalid
326+
raise ValueError(
327+
f"Input must be a string, BidiInputEvent (BidiTextInputEvent/BidiAudioInputEvent/BidiImageInputEvent), or event dict with 'type' field, got: {type(input_data)}"
328+
)
292329

293-
async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]:
330+
async def receive(self) -> AsyncIterable[BidiOutputEvent]:
294331
"""Receive events from the model including audio, text, and tool calls.
295332
296333
Yields model output events processed by background tasks including audio output,
@@ -301,9 +338,11 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]:
301338
"""
302339
while self.active:
303340
try:
304-
event = await self._output_queue.get()
341+
# Use a timeout to periodically check if we should stop
342+
event = await asyncio.wait_for(self._output_queue.get(), timeout=0.5)
305343
yield event
306344
except asyncio.TimeoutError:
345+
# Timeout allows us to check self.active periodically
307346
continue
308347

309348
async def stop(self) -> None:
@@ -313,7 +352,7 @@ async def stop(self) -> None:
313352
closes the connection to the model provider.
314353
"""
315354
if self._agent_loop:
316-
await self._agent_loop.stop()
355+
await stop_bidirectional_connection(self._agent_loop)
317356
self._agent_loop = None
318357

319358
async def __aenter__(self) -> "BidiAgent":

src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,9 @@ async def _handle_interruption(session: BidirectionalConnection) -> None:
223223
try:
224224
while True:
225225
event = session.agent._output_queue.get_nowait()
226-
if event.get("audioOutput"):
226+
# Check for audio events
227+
event_type = event.get("type", "")
228+
if event_type == "bidi_audio_stream":
227229
audio_cleared += 1
228230
else:
229231
# Keep non-audio events
@@ -267,35 +269,38 @@ async def _process_model_events(session: BidirectionalConnection) -> None:
267269

268270
strands_event = provider_event
269271

270-
# Handle interruption detection (provider converts raw patterns to interruptionDetected)
271-
if strands_event.get("interruptionDetected"):
272+
# Get event type
273+
event_type = strands_event.get("type", "")
274+
275+
# Handle interruption detection
276+
if event_type == "bidi_interruption":
272277
logger.debug("Interruption forwarded")
273278
await _handle_interruption(session)
274279
# Forward interruption event to agent for application-level handling
275280
await session.agent._output_queue.put(strands_event)
276281
continue
277282

278283
# Queue tool requests for concurrent execution
279-
if strands_event.get("toolUse"):
280-
tool_name = strands_event["toolUse"].get("name")
281-
logger.debug("Tool usage detected: %s", tool_name)
282-
await session.tool_queue.put(strands_event["toolUse"])
284+
# Check for ToolUseStreamEvent (standard agent event)
285+
if event_type == "tool_use_stream":
286+
tool_use = strands_event.get("current_tool_use")
287+
if tool_use:
288+
tool_name = tool_use.get("name")
289+
logger.debug("Tool usage detected: %s", tool_name)
290+
await session.tool_queue.put(tool_use)
291+
# Forward ToolUseStreamEvent to output queue for client visibility
292+
await session.agent._output_queue.put(strands_event)
283293
continue
284294

285-
# Send output events to Agent for receive() method
286-
if strands_event.get("audioOutput") or strands_event.get("textOutput"):
287-
await session.agent._output_queue.put(strands_event)
295+
# Send all output events to Agent for receive() method
296+
await session.agent._output_queue.put(strands_event)
288297

289-
# Update Agent conversation history using existing patterns
290-
if strands_event.get("messageStop"):
291-
logger.debug("Message added to history")
292-
session.agent.messages.append(strands_event["messageStop"]["message"])
293-
294-
# Handle user audio transcripts - add to message history
295-
if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user":
296-
user_transcript = strands_event["textOutput"]["text"]
297-
if user_transcript.strip(): # Only add non-empty transcripts
298-
user_message = {"role": "user", "content": user_transcript}
298+
# Update Agent conversation history for user transcripts
299+
if event_type == "bidi_transcript_stream":
300+
role = strands_event.get("role")
301+
text = strands_event.get("text", "")
302+
if role == "user" and text.strip():
303+
user_message = {"role": "user", "content": text}
299304
session.agent.messages.append(user_message)
300305
logger.debug("User transcript added to history")
301306

@@ -434,14 +439,19 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use:
434439
tool_result = tool_event.tool_result
435440
tool_use_id = tool_result.get("toolUseId")
436441

437-
# Send result through send() method
438-
await session.model.send(tool_result)
439-
logger.debug("Tool result sent: %s", tool_use_id)
442+
# Send ToolResultEvent through send() method to model
443+
await session.model.send(tool_event)
444+
logger.debug("Tool result sent to model: %s", tool_use_id)
445+
446+
# Also forward ToolResultEvent to output queue for client visibility
447+
await session.agent._output_queue.put(tool_event)
448+
logger.debug("Tool result sent to client: %s", tool_use_id)
440449

441450
# Handle streaming events if needed later
442451
elif isinstance(tool_event, ToolStreamEvent):
443452
logger.debug("Tool stream event: %s", tool_event)
444-
pass
453+
# Forward tool stream events to output queue
454+
await session.agent._output_queue.put(tool_event)
445455

446456
# Add tool result message to conversation history
447457
if tool_results:
@@ -464,14 +474,14 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use:
464474
except Exception as e:
465475
logger.error("Tool execution error: %s - %s", tool_name, str(e))
466476

467-
# Send error result
477+
# Send error result wrapped in ToolResultEvent
468478
error_result: ToolResult = {
469479
"toolUseId": tool_id,
470480
"status": "error",
471481
"content": [{"text": f"Error: {str(e)}"}]
472482
}
473483
try:
474-
await session.model.send(error_result)
484+
await session.model.send(ToolResultEvent(error_result))
475485
logger.debug("Error result sent: %s", tool_id)
476486
except Exception as send_error:
477487
logger.error("Failed to send error result: %s - %s", tool_id, str(send_error))

0 commit comments

Comments
 (0)