Skip to content

Commit c295c81

Browse files
committed
fix(live): forward thinking config
1 parent 4006fe4 commit c295c81

2 files changed

Lines changed: 34 additions & 0 deletions

File tree

src/google/adk/models/google_llm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,10 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
453453
' backend. Please use Vertex AI backend.'
454454
)
455455
llm_request.live_connect_config.tools = llm_request.config.tools
456+
if llm_request.config.thinking_config is not None:
457+
llm_request.live_connect_config.thinking_config = (
458+
llm_request.config.thinking_config
459+
)
456460
logger.debug('Connecting to live with llm_request:%s', llm_request)
457461
logger.debug('Live connect config: %s', llm_request.live_connect_config)
458462
async with self._live_api_client.aio.live.connect(

tests/unittests/models/test_google_llm.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,36 @@ async def __aexit__(self, *args):
852852
)
853853

854854

855+
@pytest.mark.asyncio
856+
async def test_connect_forwards_thinking_config(gemini_llm, llm_request):
857+
"""Test that live sessions keep the request thinking_config."""
858+
thinking_config = types.ThinkingConfig(thinking_budget=128)
859+
llm_request.config.thinking_config = thinking_config
860+
llm_request.live_connect_config = types.LiveConnectConfig()
861+
862+
mock_live_session = mock.AsyncMock()
863+
864+
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
865+
866+
class MockLiveConnect:
867+
868+
async def __aenter__(self):
869+
return mock_live_session
870+
871+
async def __aexit__(self, *args):
872+
pass
873+
874+
mock_live_client.aio.live.connect.return_value = MockLiveConnect()
875+
876+
async with gemini_llm.connect(llm_request) as connection:
877+
mock_live_client.aio.live.connect.assert_called_once()
878+
call_args = mock_live_client.aio.live.connect.call_args
879+
config_arg = call_args.kwargs["config"]
880+
881+
assert config_arg.thinking_config == thinking_config
882+
assert isinstance(connection, GeminiLlmConnection)
883+
884+
855885
@pytest.mark.parametrize(
856886
(
857887
"api_backend, "

0 commit comments

Comments
 (0)