Skip to content

Commit 6d027b4

Browse files
authored
fix: add missing Gemini imports in base_llm_flow (#5943)
1 parent 19a87ca commit 6d027b4

5 files changed

Lines changed: 75 additions & 41 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
from ...auth.auth_tool import AuthConfig
4040
from ...events.event import Event
4141
from ...models.base_llm_connection import BaseLlmConnection
42+
from ...models.google_llm import Gemini
43+
from ...models.google_llm import GoogleLLMVariant
4244
from ...models.llm_request import LlmRequest
4345
from ...models.llm_response import LlmResponse
4446
from ...telemetry import tracing
@@ -47,8 +49,8 @@
4749
from ...telemetry.tracing import tracer
4850
from ...tools.base_toolset import BaseToolset
4951
from ...tools.tool_context import ToolContext
50-
from ...utils.context_utils import Aclosing
5152
from ...utils import model_name_utils
53+
from ...utils.context_utils import Aclosing
5254
from .audio_cache_manager import AudioCacheManager
5355
from .functions import build_auth_request_event
5456

@@ -515,7 +517,17 @@ async def run_live(
515517
llm_request.live_connect_config.session_resumption.handle = (
516518
invocation_context.live_session_resumption_handle
517519
)
518-
llm_request.live_connect_config.session_resumption.transparent = True
520+
# Only set transparent=True for Vertex AI backend, as the Gemini API
521+
# backend explicitly rejects it.
522+
if (
523+
isinstance(llm, Gemini)
524+
and llm._api_backend == GoogleLLMVariant.VERTEX_AI # pylint: disable=protected-access
525+
):
526+
session_resumption = (
527+
llm_request.live_connect_config.session_resumption
528+
)
529+
if session_resumption.transparent is None:
530+
session_resumption.transparent = True
519531

520532
if (
521533
isinstance(llm, Gemini)
@@ -527,8 +539,8 @@ async def run_live(
527539
if llm_request.live_connect_config is None:
528540
llm_request.live_connect_config = types.LiveConnectConfig()
529541
if llm_request.live_connect_config.history_config is None:
530-
llm_request.live_connect_config.history_config = types.HistoryConfig(
531-
initial_history_in_client_content=True
542+
llm_request.live_connect_config.history_config = (
543+
types.HistoryConfig(initial_history_in_client_content=True)
532544
)
533545

534546
logger.info(

src/google/adk/flows/llm_flows/basic.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,15 @@ def _build_basic_request(
7979
llm_request.live_connect_config.realtime_input_config = (
8080
invocation_context.run_config.realtime_input_config
8181
)
82-
active_model_name = getattr(getattr(agent, 'canonical_live_model', None), 'model', None) or llm_request.model
82+
active_model_name = (
83+
getattr(getattr(agent, 'canonical_live_model', None), 'model', None)
84+
or llm_request.model
85+
)
8386
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(active_model_name)
8487
llm_request.live_connect_config.enable_affective_dialog = (
85-
None if is_gemini_31 else invocation_context.run_config.enable_affective_dialog
88+
None
89+
if is_gemini_31
90+
else invocation_context.run_config.enable_affective_dialog
8691
)
8792
llm_request.live_connect_config.proactivity = (
8893
None if is_gemini_31 else invocation_context.run_config.proactivity

src/google/adk/models/gemini_llm_connection.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,15 @@ async def send_history(self, history: list[types.Content]):
8888
# protocol error (invalid role mid-session), we consolidate previous multi-turn
8989
# interactions into a unified contextual preamble on a single user role turn.
9090
if is_gemini_31 and self._api_backend != GoogleLLMVariant.GEMINI_API:
91-
collapsed_text = "Previous conversation history:\n"
91+
collapsed_text = 'Previous conversation history:\n'
9292
for c in contents:
93-
text_parts = "".join(p.text for p in c.parts if p.text)
93+
text_parts = ''.join(p.text for p in c.parts if p.text)
9494
collapsed_text += f'[{c.role}]: {text_parts}\n'
95-
contents = [types.Content(role='user', parts=[types.Part.from_text(text=collapsed_text)])]
95+
contents = [
96+
types.Content(
97+
role='user', parts=[types.Part.from_text(text=collapsed_text)]
98+
)
99+
]
96100

97101
logger.debug('Sending history to live connection: %s', contents)
98102
await self._gemini_session.send_client_content(
@@ -276,8 +280,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
276280
text += part.text
277281
is_thought = current_is_thought
278282
llm_response.partial = True
279-
# don't yield the merged text event when receiving audio data
280-
if text and not any(p.text for p in content.parts) and not has_inline_data:
283+
if (
284+
text
285+
and not any(p.text for p in content.parts)
286+
and not has_inline_data
287+
):
281288
yield self.__build_full_text_response(text, is_thought)
282289
text = ''
283290
yield llm_response

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,22 @@
1717
from unittest import mock
1818
from unittest.mock import AsyncMock
1919

20+
from google.adk.agents.live_request_queue import LiveRequestQueue
2021
from google.adk.agents.llm_agent import Agent
2122
from google.adk.agents.run_config import RunConfig
2223
from google.adk.events.event import Event
2324
from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback
2425
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
25-
from google.adk.models.google_llm import Gemini, GoogleLLMVariant
26+
from google.adk.models.google_llm import Gemini
27+
from google.adk.models.google_llm import GoogleLLMVariant
2628
from google.adk.models.llm_request import LlmRequest
2729
from google.adk.models.llm_response import LlmResponse
2830
from google.adk.plugins.base_plugin import BasePlugin
2931
from google.adk.tools.base_toolset import BaseToolset
3032
from google.adk.tools.google_search_tool import GoogleSearchTool
3133
from google.genai import types
3234
import pytest
35+
from websockets.exceptions import ConnectionClosed
3336

3437
from ... import testing_utils
3538

@@ -490,8 +493,6 @@ async def call(self, **kwargs):
490493
@pytest.mark.asyncio
491494
async def test_run_live_reconnects_on_connection_closed():
492495
"""Test that run_live reconnects when ConnectionClosed occurs."""
493-
from google.adk.agents.live_request_queue import LiveRequestQueue
494-
from websockets.exceptions import ConnectionClosed
495496

496497
real_model = Gemini()
497498
mock_connection = mock.AsyncMock()
@@ -558,7 +559,6 @@ async def mock_receive_2():
558559
@pytest.mark.asyncio
559560
async def test_run_live_reconnects_on_api_error():
560561
"""Test that run_live reconnects when APIError occurs."""
561-
from google.adk.agents.live_request_queue import LiveRequestQueue
562562
from google.genai.errors import APIError
563563

564564
real_model = Gemini()
@@ -626,7 +626,6 @@ async def mock_receive_2():
626626
@pytest.mark.asyncio
627627
async def test_run_live_skips_send_history_on_resumption():
628628
"""Test that run_live skips send_history when resuming a session."""
629-
from google.adk.agents.live_request_queue import LiveRequestQueue
630629

631630
real_model = Gemini()
632631
mock_connection = mock.AsyncMock()
@@ -684,7 +683,6 @@ async def mock_receive():
684683
@pytest.mark.asyncio
685684
async def test_live_session_resumption_go_away():
686685
"""Test that go_away triggers reconnection."""
687-
from google.adk.agents.live_request_queue import LiveRequestQueue
688686

689687
real_model = Gemini()
690688
mock_connection = mock.AsyncMock()
@@ -743,8 +741,6 @@ async def mock_receive_2():
743741
@pytest.mark.asyncio
744742
async def test_run_live_no_reconnect_without_handle():
745743
"""Test that run_live does not reconnect when handle is missing."""
746-
from google.adk.agents.live_request_queue import LiveRequestQueue
747-
from websockets.exceptions import ConnectionClosed
748744

749745
real_model = Gemini()
750746
mock_connection = mock.AsyncMock()
@@ -786,8 +782,6 @@ async def mock_receive():
786782
@pytest.mark.asyncio
787783
async def test_run_live_reconnect_limit():
788784
"""Test that run_live stops reconnecting after 5 attempts."""
789-
from google.adk.agents.live_request_queue import LiveRequestQueue
790-
from websockets.exceptions import ConnectionClosed
791785

792786
real_model = Gemini()
793787

@@ -843,9 +837,7 @@ async def mock_receive():
843837
@pytest.mark.asyncio
844838
async def test_run_live_reconnect_reset_attempt():
845839
"""Test that attempt counter is reset on successful communication."""
846-
from google.adk.agents.live_request_queue import LiveRequestQueue
847840
from google.adk.flows.llm_flows.base_llm_flow import DEFAULT_MAX_RECONNECT_ATTEMPTS
848-
from websockets.exceptions import ConnectionClosed
849841

850842
real_model = Gemini()
851843

@@ -987,7 +979,6 @@ async def mock_receive():
987979
@pytest.mark.asyncio
988980
async def test_run_live_clears_resumption_handle_on_transfer():
989981
"""Test that run_live clears session resumption handles when transferring to another agent."""
990-
from google.adk.agents.live_request_queue import LiveRequestQueue
991982

992983
agent = Agent(name='test_agent')
993984
invocation_context = await testing_utils.create_invocation_context(
@@ -1184,21 +1175,27 @@ async def mock_receive_2():
11841175
mock_aenter = mock.AsyncMock()
11851176
mock_aenter.side_effect = [mock_connection, mock_connection_2]
11861177

1187-
with mock.patch(
1188-
'google.adk.models.google_llm.Gemini.connect'
1189-
) as mock_connect:
1190-
mock_connect.return_value.__aenter__ = mock_aenter
1178+
with mock.patch.object(
1179+
Gemini, '_api_backend', new_callable=mock.PropertyMock
1180+
) as mock_backend:
1181+
mock_backend.return_value = GoogleLLMVariant.GEMINI_API
1182+
with mock.patch(
1183+
'google.adk.models.google_llm.Gemini.connect'
1184+
) as mock_connect:
1185+
mock_connect.return_value.__aenter__ = mock_aenter
11911186

1192-
try:
1193-
async for _ in flow.run_live(invocation_context):
1187+
try:
1188+
async for _ in flow.run_live(invocation_context):
1189+
pass
1190+
except StopTestError:
11941191
pass
1195-
except StopTestError:
1196-
pass
11971192

1198-
assert mock_connect.call_count == 2
1199-
second_call_req = mock_connect.call_args_list[1][0][0]
1200-
session_resump = second_call_req.live_connect_config.session_resumption
1201-
assert session_resump.transparent is None
1193+
assert mock_connect.call_count == 2
1194+
second_call_req = mock_connect.call_args_list[1][0][0]
1195+
session_resump = (
1196+
second_call_req.live_connect_config.session_resumption
1197+
)
1198+
assert session_resump.transparent is None
12021199

12031200

12041201
@pytest.mark.asyncio
@@ -1275,7 +1272,7 @@ async def mock_receive_2():
12751272

12761273
@pytest.mark.asyncio
12771274
@pytest.mark.parametrize(
1278-
"api_backend,should_have_history_config",
1275+
'api_backend,should_have_history_config',
12791276
[
12801277
(GoogleLLMVariant.GEMINI_API, True),
12811278
(GoogleLLMVariant.VERTEX_AI, False),
@@ -1309,8 +1306,12 @@ async def mock_receive():
13091306
flow = BaseLlmFlowForTesting()
13101307

13111308
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
1309+
13121310
async def mock_preprocess(ctx, req):
1313-
req.contents = [types.Content(parts=[types.Part.from_text(text='history')])]
1311+
req.model = 'gemini-3.1-flash-live-preview'
1312+
req.contents = [
1313+
types.Content(parts=[types.Part.from_text(text='history')])
1314+
]
13141315
yield Event(id=Event.new_id(), author='test')
13151316

13161317
with mock.patch.object(

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1462,7 +1462,9 @@ async def mock_receive_generator():
14621462

14631463

14641464
@pytest.mark.asyncio
1465-
async def test_receive_multiplexed_parts(gemini_connection, mock_gemini_session):
1465+
async def test_receive_multiplexed_parts(
1466+
gemini_connection, mock_gemini_session
1467+
):
14661468
"""Test receive with multiplexed inline data and text content."""
14671469
mock_content = types.Content(
14681470
role='model',
@@ -1507,6 +1509,7 @@ async def mock_receive_generator():
15071509
async def test_send_history_gemini_31_turn_complete(mock_gemini_session):
15081510
"""Verify Gemini 3.1 Live history seeding explicitly appends turn_complete=True."""
15091511
from google.adk.models.google_llm import GoogleLLMVariant
1512+
15101513
conn = GeminiLlmConnection(
15111514
mock_gemini_session,
15121515
api_backend=GoogleLLMVariant.GEMINI_API,
@@ -1530,6 +1533,7 @@ async def test_send_history_gemini_31_turn_complete(mock_gemini_session):
15301533
async def test_send_history_collapse_vertex_ai(mock_gemini_session):
15311534
"""Verify history prompt collapse when seeding Gemini 3.1 Live on Vertex AI backend."""
15321535
from google.adk.models.google_llm import GoogleLLMVariant
1536+
15331537
conn = GeminiLlmConnection(
15341538
mock_gemini_session,
15351539
api_backend=GoogleLLMVariant.VERTEX_AI,
@@ -1544,10 +1548,15 @@ async def test_send_history_collapse_vertex_ai(mock_gemini_session):
15441548
await conn.send_history(mock_contents)
15451549

15461550
assert mock_gemini_session.send_client_content.call_count == 1
1547-
called_turns = mock_gemini_session.send_client_content.call_args.kwargs['turns']
1551+
called_turns = mock_gemini_session.send_client_content.call_args.kwargs[
1552+
'turns'
1553+
]
15481554
assert len(called_turns) == 1
15491555
assert called_turns[0].role == 'user'
15501556
assert 'Previous conversation history:' in called_turns[0].parts[0].text
15511557
assert '[user]: hi' in called_turns[0].parts[0].text
15521558
assert '[model]: hello' in called_turns[0].parts[0].text
1553-
assert mock_gemini_session.send_client_content.call_args.kwargs['turn_complete'] is True
1559+
assert (
1560+
mock_gemini_session.send_client_content.call_args.kwargs['turn_complete']
1561+
is True
1562+
)

0 commit comments

Comments
 (0)