Skip to content

Commit fafafb3

Browse files
authored
fix(models): Default grounding metadata for Gemini 3.1 live (#6018)
1 parent aafd97f commit fafafb3

4 files changed

Lines changed: 181 additions & 23 deletions

File tree

src/google/adk/models/gemini_llm_connection.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ def __init__(
5050
self._output_transcription_text: str = ''
5151
self._api_backend = api_backend
5252
self._model_version = model_version
53+
self._is_gemini_3_1_flash_live = model_name_utils.is_gemini_3_1_flash_live(
54+
model_version
55+
)
5356

5457
async def send_history(self, history: list[types.Content]):
5558
"""Sends the conversation history to the gemini model.
@@ -80,14 +83,14 @@ async def send_history(self, history: list[types.Content]):
8083
]
8184

8285
if contents:
83-
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(
84-
self._model_version
85-
)
8686
# Gemini Enterprise Agent Platform does not support history_config in the SDK.
8787
# To initialize a live session with prior history without hitting a 1007
8888
# protocol error (invalid role mid-session), we consolidate previous multi-turn
8989
# interactions into a unified contextual preamble on a single user role turn.
90-
if is_gemini_31 and self._api_backend != GoogleLLMVariant.GEMINI_API:
90+
if (
91+
self._is_gemini_3_1_flash_live
92+
and self._api_backend != GoogleLLMVariant.GEMINI_API
93+
):
9194
collapsed_text = 'Previous conversation history:\n'
9295
for c in contents:
9396
text_parts = ''.join(p.text for p in c.parts if p.text)
@@ -101,7 +104,9 @@ async def send_history(self, history: list[types.Content]):
101104
logger.debug('Sending history to live connection: %s', contents)
102105
await self._gemini_session.send_client_content(
103106
turns=contents,
104-
turn_complete=True if is_gemini_31 else (contents[-1].role == 'user'),
107+
turn_complete=True
108+
if self._is_gemini_3_1_flash_live
109+
else (contents[-1].role == 'user'),
105110
)
106111
else:
107112
logger.info('no content is sent')
@@ -126,10 +131,11 @@ async def send_content(self, content: types.Content):
126131
)
127132
else:
128133
logger.debug('Sending LLM new content %s', content)
129-
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(
130-
self._model_version
131-
)
132-
if is_gemini_31 and len(content.parts) == 1 and content.parts[0].text:
134+
if (
135+
self._is_gemini_3_1_flash_live
136+
and len(content.parts) == 1
137+
and content.parts[0].text
138+
):
133139
logger.debug('Using send_realtime_input for Gemini 3.1 text input')
134140
await self._gemini_session.send_realtime_input(
135141
text=content.parts[0].text
@@ -151,10 +157,7 @@ async def send_realtime(self, input: RealtimeInput):
151157
if isinstance(input, types.Blob):
152158
# The blob is binary and is very large. So let's not log it.
153159
logger.debug('Sending LLM Blob.')
154-
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(
155-
self._model_version
156-
)
157-
if is_gemini_31:
160+
if self._is_gemini_3_1_flash_live:
158161
if input.mime_type and input.mime_type.startswith('audio/'):
159162
await self._gemini_session.send_realtime_input(audio=input)
160163
elif input.mime_type and input.mime_type.startswith('image/'):
@@ -196,10 +199,15 @@ def __build_full_text_response(
196199
Returns:
197200
An LlmResponse containing the full text.
198201
"""
202+
part = types.Part.from_text(text=text)
203+
if is_thought:
204+
part.thought = True
205+
if grounding_metadata is None and self._is_gemini_3_1_flash_live:
206+
grounding_metadata = types.GroundingMetadata()
199207
return LlmResponse(
200208
content=types.Content(
201209
role='model',
202-
parts=[types.Part.from_text(text=text)],
210+
parts=[part],
203211
),
204212
grounding_metadata=grounding_metadata,
205213
partial=False,
@@ -214,6 +222,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
214222
"""
215223

216224
text = ''
225+
is_thought = False
217226
tool_call_parts = []
218227
pending_grounding_metadata = None
219228
async with Aclosing(self._gemini_session.receive()) as agen:
@@ -265,9 +274,12 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
265274
# grounding_metadata is yielded again at turn_complete,
266275
# so avoid duplicating it here if turn_complete is true.
267276
if not message.server_content.turn_complete:
268-
llm_response.grounding_metadata = (
269-
message.server_content.grounding_metadata
270-
)
277+
if message.server_content.grounding_metadata is not None:
278+
llm_response.grounding_metadata = (
279+
message.server_content.grounding_metadata
280+
)
281+
elif self._is_gemini_3_1_flash_live:
282+
llm_response.grounding_metadata = types.GroundingMetadata()
271283
has_inline_data = any(p.inline_data for p in content.parts)
272284
for part in content.parts:
273285
if part.text:
@@ -394,7 +406,12 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
394406
turn_complete=True,
395407
interrupted=message.server_content.interrupted,
396408
grounding_metadata=message.server_content.grounding_metadata
397-
or g_metadata_to_yield,
409+
or g_metadata_to_yield
410+
or (
411+
types.GroundingMetadata()
412+
if self._is_gemini_3_1_flash_live
413+
else None
414+
),
398415
model_version=self._model_version,
399416
live_session_id=live_session_id,
400417
turn_complete_reason=getattr(
@@ -430,10 +447,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
430447
# deadlocking the conversation. Other models (e.g. 2.5-pro,
431448
# native-audio) send turn_complete after tool calls, so buffer
432449
# and merge them into a single response at turn_complete.
433-
if (
434-
model_name_utils.is_gemini_3_1_flash_live(self._model_version)
435-
and tool_call_parts
436-
):
450+
if self._is_gemini_3_1_flash_live and tool_call_parts:
437451
logger.debug(
438452
'Yielding tool_call_parts immediately for Gemini 3.1 live tool'
439453
' call'
@@ -442,6 +456,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
442456
content=types.Content(role='model', parts=tool_call_parts),
443457
model_version=self._model_version,
444458
live_session_id=live_session_id,
459+
grounding_metadata=types.GroundingMetadata(),
445460
)
446461
tool_call_parts = []
447462
if message.session_resumption_update:

src/google/adk/utils/model_name_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,5 @@ def is_gemini_3_1_flash_live(model_string: Optional[str]) -> bool:
172172
"""
173173
if not model_string:
174174
return False
175-
return model_string.startswith('gemini-3.1-flash-live')
175+
model_name = extract_model_name(model_string)
176+
return model_name.startswith('gemini-3.1-flash-live')

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,3 +1560,119 @@ async def test_send_history_collapse_vertex_ai(mock_gemini_session):
15601560
mock_gemini_session.send_client_content.call_args.kwargs['turn_complete']
15611561
is True
15621562
)
1563+
1564+
1565+
@pytest.mark.asyncio
async def test_receive_grounding_metadata_default_gemini_3_1(
    mock_gemini_session,
):
  """Gemini 3.1 live responses fall back to an empty GroundingMetadata.

  None of the mocked server messages carries grounding metadata, so every
  response yielded by `receive()` is expected to expose a default
  `types.GroundingMetadata` instance instead of `None`.
  """
  connection = GeminiLlmConnection(
      mock_gemini_session,
      model_version='gemini-3.1-flash-live-preview',
  )

  def make_msg(text=None, tc=False, tool_call=None):
    """Builds a mocked live server message without grounding metadata."""
    model_turn = None
    if text:
      model_turn = types.Content(
          role='model', parts=[types.Part.from_text(text=text)]
      )
    message = mock.create_autospec(types.LiveServerMessage, instance=True)
    message.usage_metadata = None
    message.tool_call = tool_call
    message.session_resumption_update = None
    message.go_away = None
    message.server_content = mock.Mock(
        interrupted=False,
        input_transcription=None,
        output_transcription=None,
        generation_complete=False,
        turn_complete=tc,
        grounding_metadata=None,
        model_turn=model_turn,
    )
    return message

  # 1. Plain text content.
  content_msg = make_msg(text='hello')
  # 2. Tool call; Gemini 3.1 live yields it immediately instead of buffering.
  live_tool_call = mock.create_autospec(
      types.LiveServerToolCall, instance=True
  )
  live_tool_call.function_calls = [types.FunctionCall(name='foo', args={})]
  tool_call_msg = make_msg(tool_call=live_tool_call)
  # 3. End of turn.
  turn_complete_msg = make_msg(tc=True)

  async def fake_receive():
    for message in (content_msg, tool_call_msg, turn_complete_msg):
      yield message

  mock_gemini_session.receive = mock.Mock(return_value=fake_receive())
  responses = [response async for response in connection.receive()]

  # Expected order, each response carrying grounding_metadata:
  #   partial text -> full text -> tool call -> turn_complete.
  assert len(responses) == 4
  partial_text, full_text, tool_call_response, final = responses

  assert partial_text.content.parts[0].text == 'hello'
  assert isinstance(partial_text.grounding_metadata, types.GroundingMetadata)
  assert partial_text.grounding_metadata.web_search_queries is None
  assert partial_text.partial is True

  assert full_text.content.parts[0].text == 'hello'
  assert isinstance(full_text.grounding_metadata, types.GroundingMetadata)
  assert full_text.partial is False

  assert tool_call_response.content.parts[0].function_call.name == 'foo'
  assert isinstance(
      tool_call_response.grounding_metadata, types.GroundingMetadata
  )

  assert final.turn_complete is True
  assert isinstance(final.grounding_metadata, types.GroundingMetadata)
1629+
1630+
1631+
@pytest.mark.asyncio
async def test_receive_grounding_metadata_default_non_gemini_3_1(
    mock_gemini_session,
):
  """Models other than Gemini 3.1 live keep grounding_metadata as None.

  The mocked server messages carry no grounding metadata, and for a
  `gemini-2.5-flash-live` connection no default object is substituted.
  """
  connection = GeminiLlmConnection(
      mock_gemini_session,
      model_version='gemini-2.5-flash-live',
  )

  def make_msg(text=None, tc=False):
    """Builds a mocked live server message without grounding metadata."""
    model_turn = None
    if text:
      model_turn = types.Content(
          role='model', parts=[types.Part.from_text(text=text)]
      )
    message = mock.create_autospec(types.LiveServerMessage, instance=True)
    message.usage_metadata = None
    message.tool_call = None
    message.session_resumption_update = None
    message.go_away = None
    message.server_content = mock.Mock(
        interrupted=False,
        input_transcription=None,
        output_transcription=None,
        generation_complete=False,
        turn_complete=tc,
        grounding_metadata=None,
        model_turn=model_turn,
    )
    return message

  content_msg = make_msg(text='hello')
  turn_complete_msg = make_msg(tc=True)

  async def fake_receive():
    for message in (content_msg, turn_complete_msg):
      yield message

  mock_gemini_session.receive = mock.Mock(return_value=fake_receive())
  responses = [response async for response in connection.receive()]

  # Expected order: partial text -> full text -> turn_complete.
  assert len(responses) == 3
  partial_text, full_text, final = responses

  assert partial_text.content.parts[0].text == 'hello'
  assert partial_text.grounding_metadata is None
  assert partial_text.partial is True

  assert full_text.content.parts[0].text == 'hello'
  assert full_text.grounding_metadata is None
  assert full_text.partial is False

  assert final.turn_complete is True
  assert final.grounding_metadata is None

tests/unittests/utils/test_model_name_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from google.adk.utils.model_name_utils import extract_model_name
1818
from google.adk.utils.model_name_utils import is_gemini_1_model
19+
from google.adk.utils.model_name_utils import is_gemini_3_1_flash_live
1920
from google.adk.utils.model_name_utils import is_gemini_eap_or_2_or_above
2021
from google.adk.utils.model_name_utils import is_gemini_model
2122
from google.adk.utils.model_name_utils import is_gemini_model_id_check_disabled
@@ -338,3 +339,28 @@ def test_default_is_disabled(self, monkeypatch):
338339
def test_true_enables_check_bypass(self, monkeypatch):
339340
monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true')
340341
assert is_gemini_model_id_check_disabled() is True
342+
343+
344+
class TestIsGemini31FlashLive:
  """Table-driven checks for `is_gemini_3_1_flash_live`."""

  def test_is_gemini_3_1_flash_live_simple_name(self):
    """Bare model ids match only the gemini-3.1-flash-live family."""
    expectations = (
        ('gemini-3.1-flash-live', True),
        ('gemini-3.1-flash-live-preview', True),
        ('gemini-3.1-pro-live', False),
        ('gemini-2.5-flash-live', False),
    )
    for model_id, expected in expectations:
      assert is_gemini_3_1_flash_live(model_id) is expected

  def test_is_gemini_3_1_flash_live_path_based_name(self):
    """Path-based resource names (Vertex AI etc.) are matched as well."""
    prefix = 'projects/123/locations/us-central1/publishers/google/models/'
    expectations = (
        (f'{prefix}gemini-3.1-flash-live', True),
        (f'{prefix}gemini-3.1-flash-live-preview', True),
        (f'{prefix}gemini-3.1-flash', False),
    )
    for resource_path, expected in expectations:
      assert is_gemini_3_1_flash_live(resource_path) is expected

  def test_is_gemini_3_1_flash_live_edge_cases(self):
    """Missing or empty model strings are never treated as a match."""
    for empty_value in (None, ''):
      assert is_gemini_3_1_flash_live(empty_value) is False

0 commit comments

Comments
 (0)