Skip to content

Commit 79826e5

Browse files
committed
fix(live): forward thinking config
1 parent 9670ce2 commit 79826e5

2 files changed

Lines changed: 53 additions & 5 deletions

File tree

src/google/adk/models/google_llm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,10 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
453453
' backend. Please use Vertex AI backend.'
454454
)
455455
llm_request.live_connect_config.tools = llm_request.config.tools
456+
if llm_request.config.thinking_config is not None:
457+
llm_request.live_connect_config.thinking_config = (
458+
llm_request.config.thinking_config
459+
)
456460
logger.debug('Connecting to live with llm_request:%s', llm_request)
457461
logger.debug('Live connect config: %s', llm_request.live_connect_config)
458462
async with self._live_api_client.aio.live.connect(

tests/unittests/models/test_google_llm.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import logging
16-
import os
1716
import sys
1817
from typing import Optional
1918
from unittest import mock
@@ -732,16 +731,31 @@ def test_live_api_version_gemini_api(gemini_llm):
732731
assert gemini_llm._live_api_version == "v1alpha"
733732

734733

735-
def test_live_api_client_uses_api_version_from_google_base_url():
734+
@pytest.mark.parametrize(
735+
"base_url, expected_base_url",
736+
[
737+
(
738+
"https://generativelanguage.googleapis.com/v1alpha",
739+
"https://generativelanguage.googleapis.com/",
740+
),
741+
(
742+
"https://generativelanguage.mtls.googleapis.com/v1alpha",
743+
"https://generativelanguage.mtls.googleapis.com/",
744+
),
745+
],
746+
)
747+
def test_live_api_client_uses_api_version_from_google_base_url(
748+
base_url, expected_base_url
749+
):
736750
gemini_llm = Gemini(
737751
model="gemini-2.5-flash",
738-
base_url="https://generativelanguage.googleapis.com/v1alpha",
752+
base_url=base_url,
739753
)
740754

741755
client = gemini_llm._live_api_client
742756
http_options = client._api_client._http_options
743757

744-
assert http_options.base_url == "https://generativelanguage.googleapis.com/"
758+
assert http_options.base_url == expected_base_url
745759
assert http_options.api_version == "v1alpha"
746760

747761

@@ -832,7 +846,7 @@ async def __aexit__(self, *args):
832846
with mock.patch(
833847
"google.adk.models.google_llm.GeminiLlmConnection"
834848
) as MockGeminiLlmConnection:
835-
async with gemini_llm.connect(llm_request) as connection:
849+
async with gemini_llm.connect(llm_request):
836850
# Verify that the connect method was called with the right config
837851
mock_live_client.aio.live.connect.assert_called_once()
838852
call_args = mock_live_client.aio.live.connect.call_args
@@ -852,6 +866,36 @@ async def __aexit__(self, *args):
852866
)
853867

854868

869+
@pytest.mark.asyncio
870+
async def test_connect_forwards_thinking_config(gemini_llm, llm_request):
871+
"""Test that live sessions keep the request thinking_config."""
872+
thinking_config = types.ThinkingConfig(thinking_budget=128)
873+
llm_request.config.thinking_config = thinking_config
874+
llm_request.live_connect_config = types.LiveConnectConfig()
875+
876+
mock_live_session = mock.AsyncMock()
877+
878+
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
879+
880+
class MockLiveConnect:
881+
882+
async def __aenter__(self):
883+
return mock_live_session
884+
885+
async def __aexit__(self, *args):
886+
pass
887+
888+
mock_live_client.aio.live.connect.return_value = MockLiveConnect()
889+
890+
async with gemini_llm.connect(llm_request) as connection:
891+
mock_live_client.aio.live.connect.assert_called_once()
892+
call_args = mock_live_client.aio.live.connect.call_args
893+
config_arg = call_args.kwargs["config"]
894+
895+
assert config_arg.thinking_config == thinking_config
896+
assert isinstance(connection, GeminiLlmConnection)
897+
898+
855899
@pytest.mark.parametrize(
856900
(
857901
"api_backend, "

0 commit comments

Comments
 (0)