Skip to content

Commit d31ccee

Browse files
committed
fix(live): forward thinking config
1 parent 30493ba commit d31ccee

2 files changed

Lines changed: 53 additions & 5 deletions

File tree

src/google/adk/models/google_llm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,10 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
455455
' backend. Please use Vertex AI backend.'
456456
)
457457
llm_request.live_connect_config.tools = llm_request.config.tools
458+
if llm_request.config.thinking_config is not None:
459+
llm_request.live_connect_config.thinking_config = (
460+
llm_request.config.thinking_config
461+
)
458462
logger.debug('Connecting to live with llm_request:%s', llm_request)
459463
logger.debug('Live connect config: %s', llm_request.live_connect_config)
460464
async with self._live_api_client.aio.live.connect(

tests/unittests/models/test_google_llm.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import logging
16-
import os
1716
import sys
1817
from typing import Optional
1918
from unittest import mock
@@ -733,16 +732,31 @@ def test_live_api_version_gemini_api(gemini_llm):
733732
assert gemini_llm._live_api_version == "v1alpha"
734733

735734

736-
def test_live_api_client_uses_api_version_from_google_base_url():
735+
@pytest.mark.parametrize(
736+
"base_url, expected_base_url",
737+
[
738+
(
739+
"https://generativelanguage.googleapis.com/v1alpha",
740+
"https://generativelanguage.googleapis.com/",
741+
),
742+
(
743+
"https://generativelanguage.mtls.googleapis.com/v1alpha",
744+
"https://generativelanguage.mtls.googleapis.com/",
745+
),
746+
],
747+
)
748+
def test_live_api_client_uses_api_version_from_google_base_url(
749+
base_url, expected_base_url
750+
):
737751
gemini_llm = Gemini(
738752
model="gemini-2.5-flash",
739-
base_url="https://generativelanguage.googleapis.com/v1alpha",
753+
base_url=base_url,
740754
)
741755

742756
client = gemini_llm._live_api_client
743757
http_options = client._api_client._http_options
744758

745-
assert http_options.base_url == "https://generativelanguage.googleapis.com/"
759+
assert http_options.base_url == expected_base_url
746760
assert http_options.api_version == "v1alpha"
747761

748762

@@ -833,7 +847,7 @@ async def __aexit__(self, *args):
833847
with mock.patch(
834848
"google.adk.models.google_llm.GeminiLlmConnection"
835849
) as MockGeminiLlmConnection:
836-
async with gemini_llm.connect(llm_request) as connection:
850+
async with gemini_llm.connect(llm_request):
837851
# Verify that the connect method was called with the right config
838852
mock_live_client.aio.live.connect.assert_called_once()
839853
call_args = mock_live_client.aio.live.connect.call_args
@@ -853,6 +867,36 @@ async def __aexit__(self, *args):
853867
)
854868

855869

870+
@pytest.mark.asyncio
871+
async def test_connect_forwards_thinking_config(gemini_llm, llm_request):
872+
"""Test that live sessions keep the request thinking_config."""
873+
thinking_config = types.ThinkingConfig(thinking_budget=128)
874+
llm_request.config.thinking_config = thinking_config
875+
llm_request.live_connect_config = types.LiveConnectConfig()
876+
877+
mock_live_session = mock.AsyncMock()
878+
879+
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
880+
881+
class MockLiveConnect:
882+
883+
async def __aenter__(self):
884+
return mock_live_session
885+
886+
async def __aexit__(self, *args):
887+
pass
888+
889+
mock_live_client.aio.live.connect.return_value = MockLiveConnect()
890+
891+
async with gemini_llm.connect(llm_request) as connection:
892+
mock_live_client.aio.live.connect.assert_called_once()
893+
call_args = mock_live_client.aio.live.connect.call_args
894+
config_arg = call_args.kwargs["config"]
895+
896+
assert config_arg.thinking_config == thinking_config
897+
assert isinstance(connection, GeminiLlmConnection)
898+
899+
856900
@pytest.mark.parametrize(
857901
(
858902
"api_backend, "

0 commit comments

Comments
 (0)