@@ -133,6 +133,29 @@ async def _test_connect():
133133 return await _test_connect ()
134134
135135
136+ async def get_connect_call (api_client , model ):
137+ mock_ws = AsyncMock ()
138+ mock_ws .send = AsyncMock ()
139+ mock_ws .recv = AsyncMock (return_value = b'some response' )
140+
141+ @contextlib .asynccontextmanager
142+ async def mock_connect (uri , additional_headers = None ):
143+ mock_connect .call_args = (uri , additional_headers )
144+ yield mock_ws
145+
146+ mock_connect .call_args = None
147+
148+ @patch .object (live_music , 'connect' , new = mock_connect )
149+ async def _test_connect ():
150+ live_module = live .AsyncLive (api_client )
151+ async with live_module .music .connect (model = model ):
152+ pass
153+
154+ return mock_connect .call_args
155+
156+ return await _test_connect ()
157+
158+
136159def test_mldev_from_env (monkeypatch ):
137160 api_key = 'google_api_key'
138161 monkeypatch .setenv ('GOOGLE_API_KEY' , api_key )
@@ -168,6 +191,23 @@ def test_websocket_base_url():
168191 assert api_client ._websocket_base_url () == 'wss://test.com'
169192
170193
194+ @pytest .mark .asyncio
195+ async def test_connect_uses_api_key_header_not_url_query ():
196+ api_client = gl_client .BaseApiClient (
197+ api_key = 'TEST_API_KEY' ,
198+ http_options = {'base_url' : 'https://test.com' , 'headers' : {}},
199+ )
200+
201+ uri , headers = await get_connect_call (api_client , model = 'lyria-realtime-exp' )
202+
203+ assert uri == (
204+ 'wss://test.com/ws/google.ai.generativelanguage.v1beta.'
205+ 'GenerativeService.BidiGenerateMusic'
206+ )
207+ assert 'key=' not in uri
208+ assert headers ['x-goog-api-key' ] == 'TEST_API_KEY'
209+
210+
171211@pytest .mark .parametrize ('vertexai' , [True , False ])
172212@pytest .mark .asyncio
173213async def test_async_session_send_weighted_prompts (
0 commit comments