Skip to content

Commit 272a1fc

Browse files
author
Murat Kaan Meral
committed
Update model unit tests to use new class names
Updated test imports and usages: - GeminiLiveModel → BidiGeminiLiveModel - NovaSonicModel → BidiNovaSonicModel - OpenAIRealtimeModel → BidiOpenAIRealtimeModel Note: 21 model tests still failing because they call .connect() but models now use .start(). This is a pre-existing issue that needs separate fix - tests need API update.
1 parent 873441b commit 272a1fc

3 files changed

Lines changed: 29 additions & 29 deletions

File tree

tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Unit tests for Gemini Live bidirectional streaming model.
22
3-
Tests the unified GeminiLiveModel interface including:
3+
Tests the unified BidiGeminiLiveModel interface including:
44
- Model initialization and configuration
55
- Connection establishment and lifecycle
66
- Unified send() method with different content types
@@ -13,7 +13,7 @@
1313
from google import genai
1414
from google.genai import types as genai_types
1515

16-
from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel
16+
from strands.experimental.bidirectional_streaming.models.gemini_live import BidiGeminiLiveModel
1717
from strands.experimental.bidirectional_streaming.types.events import (
1818
BidiAudioInputEvent,
1919
BidiImageInputEvent,
@@ -56,9 +56,9 @@ def api_key():
5656

5757
@pytest.fixture
5858
def model(mock_genai_client, model_id, api_key):
59-
"""Create a GeminiLiveModel instance."""
59+
"""Create a BidiGeminiLiveModel instance."""
6060
_ = mock_genai_client
61-
return GeminiLiveModel(model_id=model_id, api_key=api_key)
61+
return BidiGeminiLiveModel(model_id=model_id, api_key=api_key)
6262

6363

6464
@pytest.fixture
@@ -88,20 +88,20 @@ def test_model_initialization(mock_genai_client, model_id, api_key):
8888
_ = mock_genai_client
8989

9090
# Test default config
91-
model_default = GeminiLiveModel()
91+
model_default = BidiGeminiLiveModel()
9292
assert model_default.model_id == "models/gemini-2.0-flash-live-preview-04-09"
9393
assert model_default.api_key is None
9494
assert model_default._active is False
9595
assert model_default.live_session is None
9696

9797
# Test with API key
98-
model_with_key = GeminiLiveModel(model_id=model_id, api_key=api_key)
98+
model_with_key = BidiGeminiLiveModel(model_id=model_id, api_key=api_key)
9999
assert model_with_key.model_id == model_id
100100
assert model_with_key.api_key == api_key
101101

102102
# Test with custom config
103103
live_config = {"temperature": 0.7, "top_p": 0.9}
104-
model_custom = GeminiLiveModel(model_id=model_id, live_config=live_config)
104+
model_custom = BidiGeminiLiveModel(model_id=model_id, live_config=live_config)
105105
assert model_custom.live_config == live_config
106106

107107

@@ -152,7 +152,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id):
152152
mock_client, _, mock_live_session_cm = mock_genai_client
153153

154154
# Test connection error
155-
model1 = GeminiLiveModel(model_id=model_id, api_key=api_key)
155+
model1 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key)
156156
mock_client.aio.live.connect.side_effect = Exception("Connection failed")
157157
with pytest.raises(Exception, match="Connection failed"):
158158
await model1.connect()
@@ -161,18 +161,18 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id):
161161
mock_client.aio.live.connect.side_effect = None
162162

163163
# Test double connection
164-
model2 = GeminiLiveModel(model_id=model_id, api_key=api_key)
164+
model2 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key)
165165
await model2.connect()
166166
with pytest.raises(RuntimeError, match="Connection already active"):
167167
await model2.connect()
168168
await model2.close()
169169

170170
# Test close when not connected
171-
model3 = GeminiLiveModel(model_id=model_id, api_key=api_key)
171+
model3 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key)
172172
await model3.close() # Should not raise
173173

174174
# Test close error handling
175-
model4 = GeminiLiveModel(model_id=model_id, api_key=api_key)
175+
model4 = BidiGeminiLiveModel(model_id=model_id, api_key=api_key)
176176
await model4.connect()
177177
mock_live_session_cm.__aexit__.side_effect = Exception("Close failed")
178178
with pytest.raises(Exception, match="Close failed"):

tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pytest_asyncio
1414

1515
from strands.experimental.bidirectional_streaming.models.novasonic import (
16-
NovaSonicModel,
16+
BidiNovaSonicModel,
1717
)
1818
from strands.types.tools import ToolResult
1919

@@ -53,7 +53,7 @@ def mock_client(mock_stream):
5353
@pytest_asyncio.fixture
5454
async def nova_model(model_id, region):
5555
"""Create Nova Sonic model instance."""
56-
model = NovaSonicModel(model_id=model_id, region=region)
56+
model = BidiNovaSonicModel(model_id=model_id, region=region)
5757
yield model
5858
# Cleanup
5959
if model._active:
@@ -66,7 +66,7 @@ async def nova_model(model_id, region):
6666
@pytest.mark.asyncio
6767
async def test_model_initialization(model_id, region):
6868
"""Test model initialization with configuration."""
69-
model = NovaSonicModel(model_id=model_id, region=region)
69+
model = BidiNovaSonicModel(model_id=model_id, region=region)
7070

7171
assert model.model_id == model_id
7272
assert model.region == region
@@ -120,7 +120,7 @@ async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model
120120
await nova_model.close()
121121

122122
# Test close when already closed
123-
model2 = NovaSonicModel(model_id=model_id, region=region)
123+
model2 = BidiNovaSonicModel(model_id=model_id, region=region)
124124
await model2.close() # Should not raise
125125
await model2.close() # Second call should also be safe
126126

tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Unit tests for OpenAI Realtime bidirectional streaming model.
22
3-
Tests the unified OpenAIRealtimeModel interface including:
3+
Tests the unified BidiOpenAIRealtimeModel interface including:
44
- Model initialization and configuration
55
- Connection establishment with WebSocket
66
- Unified send() method with different content types
@@ -15,7 +15,7 @@
1515

1616
import pytest
1717

18-
from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel
18+
from strands.experimental.bidirectional_streaming.models.openai import BidiOpenAIRealtimeModel
1919
from strands.experimental.bidirectional_streaming.types.events import (
2020
BidiAudioInputEvent,
2121
BidiImageInputEvent,
@@ -56,8 +56,8 @@ def api_key():
5656

5757
@pytest.fixture
5858
def model(api_key, model_name):
59-
"""Create an OpenAIRealtimeModel instance."""
60-
return OpenAIRealtimeModel(model=model_name, api_key=api_key)
59+
"""Create an BidiOpenAIRealtimeModel instance."""
60+
return BidiOpenAIRealtimeModel(model=model_name, api_key=api_key)
6161

6262

6363
@pytest.fixture
@@ -85,19 +85,19 @@ def messages():
8585
def test_model_initialization(api_key, model_name):
8686
"""Test model initialization with various configurations."""
8787
# Test default config
88-
model_default = OpenAIRealtimeModel(api_key="test-key")
88+
model_default = BidiOpenAIRealtimeModel(api_key="test-key")
8989
assert model_default.model == "gpt-realtime"
9090
assert model_default.api_key == "test-key"
9191
assert model_default._active is False
9292
assert model_default.websocket is None
9393

9494
# Test with custom model
95-
model_custom = OpenAIRealtimeModel(model=model_name, api_key=api_key)
95+
model_custom = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key)
9696
assert model_custom.model == model_name
9797
assert model_custom.api_key == api_key
9898

9999
# Test with organization and project
100-
model_org = OpenAIRealtimeModel(
100+
model_org = BidiOpenAIRealtimeModel(
101101
model=model_name,
102102
api_key=api_key,
103103
organization="org-123",
@@ -108,15 +108,15 @@ def test_model_initialization(api_key, model_name):
108108

109109
# Test with env API key
110110
with unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}):
111-
model_env = OpenAIRealtimeModel()
111+
model_env = BidiOpenAIRealtimeModel()
112112
assert model_env.api_key == "env-key"
113113

114114

115115
def test_init_without_api_key_raises():
116116
"""Test that initialization without API key raises error."""
117117
with unittest.mock.patch.dict("os.environ", {}, clear=True):
118118
with pytest.raises(ValueError, match="OpenAI API key is required"):
119-
OpenAIRealtimeModel()
119+
BidiOpenAIRealtimeModel()
120120

121121

122122
# Connection Tests
@@ -171,7 +171,7 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp
171171
await model.close()
172172

173173
# Test connection with organization header
174-
model_org = OpenAIRealtimeModel(api_key="test-key", organization="org-123")
174+
model_org = BidiOpenAIRealtimeModel(api_key="test-key", organization="org-123")
175175
await model_org.connect()
176176
call_kwargs = mock_connect.call_args.kwargs
177177
headers = call_kwargs.get("additional_headers", [])
@@ -187,7 +187,7 @@ async def test_connection_edge_cases(mock_websockets_connect, api_key, model_nam
187187
mock_connect, mock_ws = mock_websockets_connect
188188

189189
# Test connection error
190-
model1 = OpenAIRealtimeModel(model=model_name, api_key=api_key)
190+
model1 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key)
191191
mock_connect.side_effect = Exception("Connection failed")
192192
with pytest.raises(Exception, match="Connection failed"):
193193
await model1.connect()
@@ -198,18 +198,18 @@ async def async_connect(*args, **kwargs):
198198
mock_connect.side_effect = async_connect
199199

200200
# Test double connection
201-
model2 = OpenAIRealtimeModel(model=model_name, api_key=api_key)
201+
model2 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key)
202202
await model2.connect()
203203
with pytest.raises(RuntimeError, match="Connection already active"):
204204
await model2.connect()
205205
await model2.close()
206206

207207
# Test close when not connected
208-
model3 = OpenAIRealtimeModel(model=model_name, api_key=api_key)
208+
model3 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key)
209209
await model3.close() # Should not raise
210210

211211
# Test close error handling (should not raise, just log)
212-
model4 = OpenAIRealtimeModel(model=model_name, api_key=api_key)
212+
model4 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key)
213213
await model4.connect()
214214
mock_ws.close.side_effect = Exception("Close failed")
215215
await model4.close() # Should not raise

0 commit comments

Comments
 (0)