Skip to content

Commit 5b8b268

Browse files
committed
fix: avoid live music API key in websocket URL
1 parent 84822a1 commit 5b8b268

2 files changed

Lines changed: 41 additions & 2 deletions

File tree

google/genai/live_music.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,8 @@ async def connect(self, *, model: str) -> AsyncIterator[AsyncMusicSession]:
175175
transformed_model = t.t_model(self._api_client, model)
176176

177177
if self._api_client.api_key:
178-
api_key = self._api_client.api_key
179178
version = self._api_client._http_options.api_version
180-
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic?key={api_key}'
179+
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic'
181180
headers = self._api_client._http_options.headers
182181

183182
# Only mldev supported

google/genai/tests/live/test_live_music.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,29 @@ async def _test_connect():
133133
return await _test_connect()
134134

135135

136+
async def get_connect_call(api_client, model):
137+
mock_ws = AsyncMock()
138+
mock_ws.send = AsyncMock()
139+
mock_ws.recv = AsyncMock(return_value=b'some response')
140+
141+
@contextlib.asynccontextmanager
142+
async def mock_connect(uri, additional_headers=None):
143+
mock_connect.call_args = (uri, additional_headers)
144+
yield mock_ws
145+
146+
mock_connect.call_args = None
147+
148+
@patch.object(live_music, 'connect', new=mock_connect)
149+
async def _test_connect():
150+
live_module = live.AsyncLive(api_client)
151+
async with live_module.music.connect(model=model):
152+
pass
153+
154+
return mock_connect.call_args
155+
156+
return await _test_connect()
157+
158+
136159
def test_mldev_from_env(monkeypatch):
137160
api_key = 'google_api_key'
138161
monkeypatch.setenv('GOOGLE_API_KEY', api_key)
@@ -168,6 +191,23 @@ def test_websocket_base_url():
168191
assert api_client._websocket_base_url() == 'wss://test.com'
169192

170193

194+
@pytest.mark.asyncio
195+
async def test_connect_uses_api_key_header_not_url_query():
196+
api_client = gl_client.BaseApiClient(
197+
api_key='TEST_API_KEY',
198+
http_options={'base_url': 'https://test.com', 'headers': {}},
199+
)
200+
201+
uri, headers = await get_connect_call(api_client, model='lyria-realtime-exp')
202+
203+
assert uri == (
204+
'wss://test.com/ws/google.ai.generativelanguage.v1beta.'
205+
'GenerativeService.BidiGenerateMusic'
206+
)
207+
assert 'key=' not in uri
208+
assert headers['x-goog-api-key'] == 'TEST_API_KEY'
209+
210+
171211
@pytest.mark.parametrize('vertexai', [True, False])
172212
@pytest.mark.asyncio
173213
async def test_async_session_send_weighted_prompts(

0 commit comments

Comments
 (0)