Skip to content

Commit 328b5dd

Browse files
author
Murat Kaan Meral
committed
add timeout to agent.receive and fix integ tests
1 parent: 5eca8f9 · commit: 328b5dd

6 files changed

Lines changed: 34 additions & 50 deletions

File tree

src/strands/experimental/bidirectional_streaming/agent/agent.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -339,9 +339,11 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]:
339339
"""
340340
while self.active:
341341
try:
342-
event = await self._output_queue.get()
342+
# Use a timeout to periodically check if we should stop
343+
event = await asyncio.wait_for(self._output_queue.get(), timeout=0.5)
343344
yield event
344345
except asyncio.TimeoutError:
346+
# Timeout allows us to check self.active periodically
345347
continue
346348

347349
async def stop(self) -> None:

tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py

Lines changed: 7 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77
- Event receiving and conversion
88
"""
99

10+
import base64
11+
import json
1012
import unittest.mock
1113

1214
import pytest
@@ -16,8 +18,13 @@
1618
from strands.experimental.bidirectional_streaming.models.gemini_live import BidiGeminiLiveModel
1719
from strands.experimental.bidirectional_streaming.types.events import (
1820
BidiAudioInputEvent,
21+
BidiAudioStreamEvent,
22+
BidiConnectionCloseEvent,
23+
BidiConnectionStartEvent,
1924
BidiImageInputEvent,
25+
BidiInterruptionEvent,
2026
BidiTextInputEvent,
27+
BidiTranscriptStreamEvent,
2128
)
2229
from strands.types._events import ToolResultEvent
2330
from strands.types.tools import ToolResult
@@ -198,7 +205,6 @@ async def test_send_all_content_types(mock_genai_client, model):
198205
assert content.parts[0].text == "Hello"
199206

200207
# Test audio input (base64 encoded)
201-
import base64
202208
audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8')
203209
audio_input = BidiAudioInputEvent(
204210
audio=audio_b64,
@@ -219,7 +225,6 @@ async def test_send_all_content_types(mock_genai_client, model):
219225
mock_live_session.send.assert_called_once()
220226

221227
# Test tool result
222-
from strands.types._events import ToolResultEvent
223228
tool_result: ToolResult = {
224229
"toolUseId": "tool-123",
225230
"status": "success",
@@ -255,11 +260,6 @@ async def test_send_edge_cases(mock_genai_client, model):
255260
@pytest.mark.asyncio
256261
async def test_receive_lifecycle_events(mock_genai_client, model, agenerator):
257262
"""Test that receive() emits connection start and end events."""
258-
from strands.experimental.bidirectional_streaming.types.events import (
259-
BidiConnectionStartEvent,
260-
BidiConnectionCloseEvent,
261-
)
262-
263263
_, mock_live_session, _ = mock_genai_client
264264
mock_live_session.receive.return_value = agenerator([])
265265

@@ -285,12 +285,6 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator):
285285
@pytest.mark.asyncio
286286
async def test_event_conversion(mock_genai_client, model):
287287
"""Test conversion of all Gemini Live event types to standard format."""
288-
from strands.experimental.bidirectional_streaming.types.events import (
289-
BidiTranscriptStreamEvent,
290-
BidiAudioStreamEvent,
291-
BidiInterruptionEvent,
292-
)
293-
294288
_, _, _ = mock_genai_client
295289
await model.start()
296290

@@ -311,7 +305,6 @@ async def test_event_conversion(mock_genai_client, model):
311305
assert text_event.current_transcript == "Hello from Gemini"
312306

313307
# Test audio output (base64 encoded)
314-
import base64
315308
mock_audio = unittest.mock.Mock()
316309
mock_audio.text = None
317310
mock_audio.data = b"audio_data"

tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py

Lines changed: 11 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,17 @@
1515
from strands.experimental.bidirectional_streaming.models.novasonic import (
1616
BidiNovaSonicModel,
1717
)
18+
from strands.experimental.bidirectional_streaming.types.events import (
19+
BidiAudioInputEvent,
20+
BidiAudioStreamEvent,
21+
BidiImageInputEvent,
22+
BidiInterruptionEvent,
23+
BidiResponseStartEvent,
24+
BidiTextInputEvent,
25+
BidiTranscriptStreamEvent,
26+
BidiUsageEvent,
27+
)
28+
from strands.types._events import ToolResultEvent
1829
from strands.types.tools import ToolResult
1930

2031

@@ -131,12 +142,6 @@ async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model
131142
@pytest.mark.asyncio
132143
async def test_send_all_content_types(nova_model, mock_client, mock_stream):
133144
"""Test sending all content types through unified send() method."""
134-
from strands.experimental.bidirectional_streaming.types.events import (
135-
BidiTextInputEvent,
136-
BidiAudioInputEvent,
137-
)
138-
from strands.types._events import ToolResultEvent
139-
140145
with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
141146
nova_model.client = mock_client
142147

@@ -177,11 +182,6 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream):
177182
@pytest.mark.asyncio
178183
async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog):
179184
"""Test send() edge cases and error handling."""
180-
from strands.experimental.bidirectional_streaming.types.events import (
181-
BidiTextInputEvent,
182-
BidiImageInputEvent,
183-
)
184-
185185
with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
186186
nova_model.client = mock_client
187187

@@ -191,7 +191,6 @@ async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog):
191191

192192
# Test image content (not supported, base64 encoded, no encoding parameter)
193193
await nova_model.start()
194-
import base64
195194
image_b64 = base64.b64encode(b"image data").decode('utf-8')
196195
image_event = BidiImageInputEvent(
197196
image=image_b64,
@@ -237,7 +236,6 @@ async def mock_wait_for(*args, **kwargs):
237236
async def test_event_conversion(nova_model):
238237
"""Test conversion of all Nova Sonic event types to standard format."""
239238
# Test audio output (now returns BidiAudioStreamEvent)
240-
from strands.experimental.bidirectional_streaming.types.events import BidiAudioStreamEvent
241239
audio_bytes = b"test audio data"
242240
audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
243241
nova_event = {"audioOutput": {"content": audio_base64}}
@@ -251,7 +249,6 @@ async def test_event_conversion(nova_model):
251249
assert result.get("sample_rate") == 24000
252250

253251
# Test text output (now returns BidiTranscriptStreamEvent)
254-
from strands.experimental.bidirectional_streaming.types.events import BidiTranscriptStreamEvent
255252
nova_event = {"textOutput": {"content": "Hello, world!", "role": "ASSISTANT"}}
256253
result = nova_model._convert_nova_event(nova_event)
257254
assert result is not None
@@ -282,7 +279,6 @@ async def test_event_conversion(nova_model):
282279
assert tool_use["input"] == tool_input
283280

284281
# Test interruption (now returns BidiInterruptionEvent)
285-
from strands.experimental.bidirectional_streaming.types.events import BidiInterruptionEvent
286282
nova_event = {"stopReason": "INTERRUPTED"}
287283
result = nova_model._convert_nova_event(nova_event)
288284
assert result is not None
@@ -291,7 +287,6 @@ async def test_event_conversion(nova_model):
291287
assert result.get("reason") == "user_speech"
292288

293289
# Test usage metrics (now returns BidiUsageEvent)
294-
from strands.experimental.bidirectional_streaming.types.events import BidiUsageEvent
295290
nova_event = {
296291
"usageEvent": {
297292
"totalTokens": 100,
@@ -315,7 +310,6 @@ async def test_event_conversion(nova_model):
315310
assert result.get("outputTokens") == 60
316311

317312
# Test content start tracks role and emits BidiResponseStartEvent
318-
from strands.experimental.bidirectional_streaming.types.events import BidiResponseStartEvent
319313
nova_event = {"contentStart": {"role": "USER"}}
320314
result = nova_model._convert_nova_event(nova_event)
321315
assert result is not None
@@ -349,16 +343,13 @@ async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream):
349343
@pytest.mark.asyncio
350344
async def test_silence_detection(nova_model, mock_client, mock_stream):
351345
"""Test that silence detection automatically ends audio input."""
352-
from strands.experimental.bidirectional_streaming.types.events import BidiAudioInputEvent
353-
354346
with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
355347
nova_model.client = mock_client
356348
nova_model.silence_threshold = 0.1 # Short threshold for testing
357349

358350
await nova_model.start()
359351

360352
# Send audio to start connection (base64 encoded)
361-
import base64
362353
audio_b64 = base64.b64encode(b"audio data").decode('utf-8')
363354
audio_event = BidiAudioInputEvent(
364355
audio=audio_b64,

tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,13 @@
1818
from strands.experimental.bidirectional_streaming.models.openai import BidiOpenAIRealtimeModel
1919
from strands.experimental.bidirectional_streaming.types.events import (
2020
BidiAudioInputEvent,
21+
BidiAudioStreamEvent,
2122
BidiImageInputEvent,
23+
BidiInterruptionEvent,
2224
BidiTextInputEvent,
25+
BidiTranscriptStreamEvent,
2326
)
27+
from strands.types._events import ToolResultEvent
2428
from strands.types.tools import ToolResult
2529

2630

@@ -222,8 +226,6 @@ async def async_connect(*args, **kwargs):
222226
@pytest.mark.asyncio
223227
async def test_send_all_content_types(mock_websockets_connect, model):
224228
"""Test sending all content types through unified send() method."""
225-
from strands.types._events import ToolResultEvent
226-
227229
_, mock_ws = mock_websockets_connect
228230
await model.start()
229231

@@ -343,7 +345,6 @@ async def test_event_conversion(mock_websockets_connect, model):
343345
await model.start()
344346

345347
# Test audio output (now returns list with BidiAudioStreamEvent)
346-
from strands.experimental.bidirectional_streaming.types.events import BidiAudioStreamEvent
347348
audio_event = {
348349
"type": "response.output_audio.delta",
349350
"delta": base64.b64encode(b"audio_data").decode()
@@ -357,7 +358,6 @@ async def test_event_conversion(mock_websockets_connect, model):
357358
assert converted[0].get("format") == "pcm"
358359

359360
# Test text output (now returns list with BidiTranscriptStreamEvent)
360-
from strands.experimental.bidirectional_streaming.types.events import BidiTranscriptStreamEvent
361361
text_event = {
362362
"type": "response.output_text.delta",
363363
"delta": "Hello from OpenAI"
@@ -407,7 +407,6 @@ async def test_event_conversion(mock_websockets_connect, model):
407407
assert tool_use["input"]["expression"] == "2+2"
408408

409409
# Test voice activity (now returns list with BidiInterruptionEvent for speech_started)
410-
from strands.experimental.bidirectional_streaming.types.events import BidiInterruptionEvent
411410
speech_started = {
412411
"type": "input_audio_buffer.speech_started"
413412
}
@@ -465,7 +464,6 @@ def test_helper_methods(model):
465464
model._active = False
466465

467466
# Test _create_text_event (now returns BidiTranscriptStreamEvent)
468-
from strands.experimental.bidirectional_streaming.types.events import BidiTranscriptStreamEvent
469467
text_event = model._create_text_event("Hello", "user")
470468
assert isinstance(text_event, BidiTranscriptStreamEvent)
471469
assert text_event.get("type") == "bidi_transcript_stream"
@@ -476,7 +474,6 @@ def test_helper_methods(model):
476474
assert text_event.current_transcript == "Hello"
477475

478476
# Test _create_voice_activity_event (now returns BidiInterruptionEvent for speech_started)
479-
from strands.experimental.bidirectional_streaming.types.events import BidiInterruptionEvent
480477
voice_event = model._create_voice_activity_event("speech_started")
481478
assert isinstance(voice_event, BidiInterruptionEvent)
482479
assert voice_event.get("type") == "bidi_interruption"

tests_integ/bidirectional_streaming/context.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
"""
66

77
import asyncio
8+
import base64
89
import logging
910
import time
1011
from typing import TYPE_CHECKING
@@ -81,12 +82,13 @@ async def __aenter__(self):
8182

8283
async def __aexit__(self, exc_type, exc_val, exc_tb):
8384
"""Stop context manager, cleanup threads, and end agent session."""
84-
await self.stop()
85-
86-
# End agent session
85+
# End agent session FIRST - this will cause receive() to exit cleanly
8786
if self.agent._agent_loop and self.agent._agent_loop.active:
8887
await self.agent.stop()
89-
logger.debug("Agent session ended")
88+
logger.debug("Agent session stopped")
89+
90+
# Then stop the context threads
91+
await self.stop()
9092

9193
return False
9294

@@ -254,8 +256,6 @@ def get_audio_outputs(self) -> list[bytes]:
254256
Returns:
255257
List of audio data bytes.
256258
"""
257-
import base64
258-
259259
# Drain queue first to get latest events
260260
events = self.get_events()
261261
audio_data = []
@@ -332,6 +332,7 @@ async def _input_thread(self):
332332

333333
except asyncio.CancelledError:
334334
logger.debug("Input thread cancelled")
335+
raise # Re-raise to properly propagate cancellation
335336
except Exception as e:
336337
logger.error(f"Input thread error: {e}", exc_info=True)
337338
finally:
@@ -350,6 +351,7 @@ async def _event_collection_thread(self):
350351

351352
except asyncio.CancelledError:
352353
logger.debug("Event collection thread cancelled")
354+
raise # Re-raise to properly propagate cancellation
353355
except Exception as e:
354356
logger.error(f"Event collection thread error: {e}")
355357

tests_integ/bidirectional_streaming/generators/audio.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
without requiring physical audio devices or pre-recorded files.
55
"""
66

7+
import base64
78
import hashlib
89
import logging
910
from pathlib import Path
@@ -120,8 +121,6 @@ def create_audio_input_event(
120121
Returns:
121122
BidiAudioInputEvent dict ready for agent.send().
122123
"""
123-
import base64
124-
125124
# Convert bytes to base64 string for JSON compatibility
126125
audio_b64 = base64.b64encode(audio_data).decode('utf-8')
127126

0 commit comments

Comments
 (0)