Skip to content

Commit e5af12c

Browse files
sasha-gitgwukath
authored and committed
fix(live): Resolve 1007 error and support Gemini 3.1 Flash Live protocol
Resolves protocol handling discrepancies when connecting to Gemini 3.1 Flash Live models:
- Sets initial_history_in_client_content=True when seeding conversation history during connection handshake.
- Appends turn_complete=True on the final history turn during setup.
- Iterates sequentially through all parts of model_turn to unpack multiplexed audio and text responses.
- Prunes unsupported proactivity and affective dialogue configurations when assembling LiveConnectConfig.

Change-Id: I2d0ff38d8a6eb40ea17b37f65a4ddd093230842c
1 parent 5bebfd4 commit e5af12c

5 files changed

Lines changed: 213 additions & 15 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,7 @@
5151
from ...tools.base_toolset import BaseToolset
5252
from ...tools.tool_context import ToolContext
5353
from ...utils.context_utils import Aclosing
54+
from ...utils import model_name_utils
5455
from .audio_cache_manager import AudioCacheManager
5556
from .functions import build_auth_request_event
5657

@@ -552,6 +553,20 @@ async def run_live(
552553
if session_resumption.transparent is None:
553554
session_resumption.transparent = True
554555

556+
if (
557+
isinstance(llm, Gemini)
558+
and llm._api_backend == GoogleLLMVariant.GEMINI_API
559+
and model_name_utils.is_gemini_3_1_flash_live(llm_request.model)
560+
and llm_request.contents
561+
and not invocation_context.live_session_resumption_handle
562+
):
563+
if llm_request.live_connect_config is None:
564+
llm_request.live_connect_config = types.LiveConnectConfig()
565+
if llm_request.live_connect_config.history_config is None:
566+
llm_request.live_connect_config.history_config = types.HistoryConfig(
567+
initial_history_in_client_content=True
568+
)
569+
555570
logger.info(
556571
'Establishing live connection for agent: %s',
557572
invocation_context.agent.name,

src/google/adk/flows/llm_flows/basic.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
from ...agents.invocation_context import InvocationContext
2626
from ...events.event import Event
2727
from ...models.llm_request import LlmRequest
28+
from ...utils import model_name_utils
2829
from ...utils.output_schema_utils import can_use_output_schema_with_tools
2930
from ._base_llm_processor import BaseLlmRequestProcessor
3031

@@ -82,11 +83,13 @@ def _build_basic_request(
8283
llm_request.live_connect_config.realtime_input_config = (
8384
invocation_context.run_config.realtime_input_config
8485
)
86+
active_model_name = getattr(getattr(agent, 'canonical_live_model', None), 'model', None) or llm_request.model
87+
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(active_model_name)
8588
llm_request.live_connect_config.enable_affective_dialog = (
86-
invocation_context.run_config.enable_affective_dialog
89+
None if is_gemini_31 else invocation_context.run_config.enable_affective_dialog
8790
)
8891
llm_request.live_connect_config.proactivity = (
89-
invocation_context.run_config.proactivity
92+
None if is_gemini_31 else invocation_context.run_config.proactivity
9093
)
9194
llm_request.live_connect_config.session_resumption = (
9295
invocation_context.run_config.session_resumption

src/google/adk/models/gemini_llm_connection.py

Lines changed: 28 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -80,10 +80,24 @@ async def send_history(self, history: list[types.Content]):
8080
]
8181

8282
if contents:
83+
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(
84+
self._model_version
85+
)
86+
# Gemini Enterprise Agent Platform does not support history_config in the SDK.
87+
# To initialize a live session with prior history without hitting a 1007
88+
# protocol error (invalid role mid-session), we consolidate previous multi-turn
89+
# interactions into a unified contextual preamble on a single user role turn.
90+
if is_gemini_31 and self._api_backend != GoogleLLMVariant.GEMINI_API:
91+
collapsed_text = "Previous conversation history:\n"
92+
for c in contents:
93+
text_parts = "".join(p.text for p in c.parts if p.text)
94+
collapsed_text += f'[{c.role}]: {text_parts}\n'
95+
contents = [types.Content(role='user', parts=[types.Part.from_text(text=collapsed_text)])]
96+
8397
logger.debug('Sending history to live connection: %s', contents)
8498
await self._gemini_session.send_client_content(
8599
turns=contents,
86-
turn_complete=contents[-1].role == 'user',
100+
turn_complete=True if is_gemini_31 else (contents[-1].role == 'user'),
87101
)
88102
else:
89103
logger.info('no content is sent')
@@ -254,18 +268,20 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
254268
llm_response.grounding_metadata = (
255269
message.server_content.grounding_metadata
256270
)
257-
if content.parts[0].text:
258-
current_is_thought = getattr(content.parts[0], 'thought', False)
259-
if text and current_is_thought != is_thought:
260-
yield self.__build_full_text_response(text, is_thought)
261-
text = ''
262-
is_thought = False
263-
264-
text += content.parts[0].text
265-
is_thought = current_is_thought
266-
llm_response.partial = True
271+
has_inline_data = any(p.inline_data for p in content.parts)
272+
for part in content.parts:
273+
if part.text:
274+
current_is_thought = getattr(part, 'thought', False)
275+
if text and current_is_thought != is_thought:
276+
yield self.__build_full_text_response(text, is_thought)
277+
text = ''
278+
is_thought = False
279+
280+
text += part.text
281+
is_thought = current_is_thought
282+
llm_response.partial = True
267283
# don't yield the merged text event when receiving audio data
268-
elif text and not content.parts[0].inline_data:
284+
if text and not any(p.text for p in content.parts) and not has_inline_data:
269285
yield self.__build_full_text_response(text, is_thought)
270286
text = ''
271287
is_thought = False

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 73 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
from google.adk.events.event import Event
2525
from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback
2626
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
27-
from google.adk.models.google_llm import Gemini
27+
from google.adk.models.google_llm import Gemini, GoogleLLMVariant
2828
from google.adk.models.llm_request import LlmRequest
2929
from google.adk.models.llm_response import LlmResponse
3030
from google.adk.plugins.base_plugin import BasePlugin
@@ -1386,3 +1386,75 @@ async def mock_receive_2():
13861386
second_call_req = mock_connect.call_args_list[1][0][0]
13871387
session_resump = second_call_req.live_connect_config.session_resumption
13881388
assert session_resump.transparent
1389+
1390+
1391+
@pytest.mark.asyncio
1392+
@pytest.mark.parametrize(
1393+
"api_backend,should_have_history_config",
1394+
[
1395+
(GoogleLLMVariant.GEMINI_API, True),
1396+
(GoogleLLMVariant.VERTEX_AI, False),
1397+
],
1398+
)
1399+
async def test_run_live_history_config_gated_by_backend(
1400+
api_backend, should_have_history_config
1401+
):
1402+
"""Test that run_live only sets history_config for Gemini API backend."""
1403+
1404+
real_model = Gemini(model='gemini-3.1-flash-live-preview')
1405+
mock_connection = mock.AsyncMock()
1406+
1407+
class StopTestError(Exception):
1408+
pass
1409+
1410+
async def mock_receive():
1411+
yield LlmResponse(
1412+
content=types.Content(parts=[types.Part.from_text(text='hi')])
1413+
)
1414+
raise StopTestError('stop')
1415+
1416+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
1417+
1418+
agent = Agent(name='test_agent', model=real_model)
1419+
invocation_context = await testing_utils.create_invocation_context(
1420+
agent=agent
1421+
)
1422+
invocation_context.live_request_queue = LiveRequestQueue()
1423+
1424+
flow = BaseLlmFlowForTesting()
1425+
1426+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
1427+
async def mock_preprocess(ctx, req):
1428+
req.contents = [types.Content(parts=[types.Part.from_text(text='history')])]
1429+
yield Event(id=Event.new_id(), author='test')
1430+
1431+
with mock.patch.object(
1432+
flow, '_preprocess_async', side_effect=mock_preprocess
1433+
):
1434+
with mock.patch.object(
1435+
Gemini, '_api_backend', new_callable=mock.PropertyMock
1436+
) as mock_backend:
1437+
mock_backend.return_value = api_backend
1438+
with mock.patch(
1439+
'google.adk.models.google_llm.Gemini.connect'
1440+
) as mock_connect:
1441+
mock_connect.return_value.__aenter__.return_value = mock_connection
1442+
1443+
try:
1444+
async for _ in flow.run_live(invocation_context):
1445+
pass
1446+
except StopTestError:
1447+
pass
1448+
1449+
assert mock_connect.call_count == 1
1450+
called_req = mock_connect.call_args[0][0]
1451+
if should_have_history_config:
1452+
assert called_req.live_connect_config is not None
1453+
assert called_req.live_connect_config.history_config is not None
1454+
assert (
1455+
called_req.live_connect_config.history_config.initial_history_in_client_content
1456+
is True
1457+
)
1458+
else:
1459+
if called_req.live_connect_config:
1460+
assert called_req.live_connect_config.history_config is None

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 92 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1540,3 +1540,95 @@ async def mock_receive_generator():
15401540
responses[0].turn_complete_reason
15411541
== types.TurnCompleteReason.RESPONSE_REJECTED
15421542
)
1543+
1544+
1545+
@pytest.mark.asyncio
1546+
async def test_receive_multiplexed_parts(gemini_connection, mock_gemini_session):
1547+
"""Test receive with multiplexed inline data and text content."""
1548+
mock_content = types.Content(
1549+
role='model',
1550+
parts=[
1551+
types.Part(
1552+
inline_data=types.Blob(data=b'audio_data', mime_type='audio/pcm')
1553+
),
1554+
types.Part.from_text(text='transcription text'),
1555+
],
1556+
)
1557+
mock_server_content = mock.Mock()
1558+
mock_server_content.model_turn = mock_content
1559+
mock_server_content.interrupted = False
1560+
mock_server_content.input_transcription = None
1561+
mock_server_content.output_transcription = None
1562+
mock_server_content.turn_complete = False
1563+
mock_server_content.grounding_metadata = None
1564+
1565+
mock_message = mock.AsyncMock()
1566+
mock_message.usage_metadata = None
1567+
mock_message.server_content = mock_server_content
1568+
mock_message.tool_call = None
1569+
mock_message.session_resumption_update = None
1570+
mock_message.go_away = None
1571+
1572+
async def mock_receive_generator():
1573+
yield mock_message
1574+
1575+
receive_mock = mock.Mock(return_value=mock_receive_generator())
1576+
mock_gemini_session.receive = receive_mock
1577+
1578+
responses = [resp async for resp in gemini_connection.receive()]
1579+
1580+
assert responses
1581+
content_response = next((r for r in responses if r.content), None)
1582+
assert content_response is not None
1583+
assert content_response.content == mock_content
1584+
assert content_response.partial is True
1585+
1586+
1587+
@pytest.mark.asyncio
1588+
async def test_send_history_gemini_31_turn_complete(mock_gemini_session):
1589+
"""Verify Gemini 3.1 Live history seeding explicitly appends turn_complete=True."""
1590+
from google.adk.models.google_llm import GoogleLLMVariant
1591+
conn = GeminiLlmConnection(
1592+
mock_gemini_session,
1593+
api_backend=GoogleLLMVariant.GEMINI_API,
1594+
model_version='gemini-3.1-flash-live-preview',
1595+
)
1596+
mock_gemini_session.send_client_content = mock.AsyncMock()
1597+
1598+
mock_contents = [
1599+
types.Content(role='user', parts=[types.Part.from_text(text='hi')]),
1600+
types.Content(role='model', parts=[types.Part.from_text(text='hello')]),
1601+
]
1602+
await conn.send_history(mock_contents)
1603+
1604+
mock_gemini_session.send_client_content.assert_called_once_with(
1605+
turns=mock_contents,
1606+
turn_complete=True,
1607+
)
1608+
1609+
1610+
@pytest.mark.asyncio
1611+
async def test_send_history_collapse_vertex_ai(mock_gemini_session):
1612+
"""Verify history prompt collapse when seeding Gemini 3.1 Live on Vertex AI backend."""
1613+
from google.adk.models.google_llm import GoogleLLMVariant
1614+
conn = GeminiLlmConnection(
1615+
mock_gemini_session,
1616+
api_backend=GoogleLLMVariant.VERTEX_AI,
1617+
model_version='gemini-3.1-flash-live-preview',
1618+
)
1619+
mock_gemini_session.send_client_content = mock.AsyncMock()
1620+
1621+
mock_contents = [
1622+
types.Content(role='user', parts=[types.Part.from_text(text='hi')]),
1623+
types.Content(role='model', parts=[types.Part.from_text(text='hello')]),
1624+
]
1625+
await conn.send_history(mock_contents)
1626+
1627+
assert mock_gemini_session.send_client_content.call_count == 1
1628+
called_turns = mock_gemini_session.send_client_content.call_args.kwargs['turns']
1629+
assert len(called_turns) == 1
1630+
assert called_turns[0].role == 'user'
1631+
assert 'Previous conversation history:' in called_turns[0].parts[0].text
1632+
assert '[user]: hi' in called_turns[0].parts[0].text
1633+
assert '[model]: hello' in called_turns[0].parts[0].text
1634+
assert mock_gemini_session.send_client_content.call_args.kwargs['turn_complete'] is True

0 commit comments

Comments
 (0)