Skip to content

Commit e896c62

Browse files
committed
fix(models): Prevent grounding metadata loss in Gemini 3.1
In Gemini 3.1 Live, grounding metadata is often emitted early, either alongside partial chunks or immediately after tool execution. Because it is then attached to partial=True events, the runner discards it.
- Retain pending grounding metadata during the turn in GeminiLlmConnection.receive and yield it on the final non-partial event.
- Add the test_receive_grounding_metadata_pending unit test to verify the fix.
Change-Id: I4d8ab1f231a3fed63c4375cbfcd8831d6535721e
1 parent 4366cca commit e896c62

2 files changed

Lines changed: 87 additions & 3 deletions

File tree

src/google/adk/models/gemini_llm_connection.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,12 @@ async def send_realtime(self, input: RealtimeInput):
159159
else:
160160
raise ValueError('Unsupported input type: %s' % type(input))
161161

162-
def __build_full_text_response(self, text: str, is_thought: bool = False):
162+
def __build_full_text_response(
163+
self,
164+
text: str,
165+
is_thought: bool = False,
166+
grounding_metadata: types.GroundingMetadata | None = None,
167+
):
163168
"""Builds a full text response.
164169
165170
The text should not be partial and the returned LlmResponse is not
@@ -168,6 +173,7 @@ def __build_full_text_response(self, text: str, is_thought: bool = False):
168173
Args:
169174
text: The text to be included in the response.
170175
is_thought: Whether the text is a thought.
176+
grounding_metadata: The grounding metadata to include.
171177
172178
Returns:
173179
An LlmResponse containing the full text.
@@ -180,6 +186,7 @@ def __build_full_text_response(self, text: str, is_thought: bool = False):
180186
role='model',
181187
parts=[part],
182188
),
189+
grounding_metadata=grounding_metadata,
183190
partial=False,
184191
live_session_id=self._gemini_session.session_id,
185192
)
@@ -194,6 +201,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
194201
text = ''
195202
is_thought = False
196203
tool_call_parts = []
204+
pending_grounding_metadata = None
197205
async with Aclosing(self._gemini_session.receive()) as agen:
198206
# TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
199207
# partial content and emit responses as needed.
@@ -209,6 +217,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
209217
)
210218
if message.server_content:
211219
content = message.server_content.model_turn
220+
if message.server_content.grounding_metadata:
221+
pending_grounding_metadata = (
222+
message.server_content.grounding_metadata
223+
)
212224

213225
# Standalone grounding_metadata event (when content is empty)
214226
if (
@@ -338,10 +350,14 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
338350
)
339351
self._output_transcription_text = ''
340352
if message.server_content.turn_complete:
353+
g_metadata_to_yield = pending_grounding_metadata
341354
if text:
342-
yield self.__build_full_text_response(text, is_thought)
355+
yield self.__build_full_text_response(
356+
text, is_thought, g_metadata_to_yield
357+
)
343358
text = ''
344359
is_thought = False
360+
g_metadata_to_yield = None
345361
if tool_call_parts:
346362
logger.debug('Returning aggregated tool_call_parts')
347363
yield LlmResponse(
@@ -353,7 +369,8 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
353369
yield LlmResponse(
354370
turn_complete=True,
355371
interrupted=message.server_content.interrupted,
356-
grounding_metadata=message.server_content.grounding_metadata,
372+
grounding_metadata=message.server_content.grounding_metadata
373+
or g_metadata_to_yield,
357374
model_version=self._model_version,
358375
live_session_id=live_session_id,
359376
)

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,3 +1343,70 @@ async def mock_receive_generator():
13431343
content_response = next((r for r in responses if r.content), None)
13441344
assert content_response is not None
13451345
assert content_response.content == mock_content
1346+
1347+
1348+
@pytest.mark.asyncio
1349+
async def test_receive_grounding_metadata_pending(
1350+
gemini_connection, mock_gemini_session
1351+
):
1352+
"""Test that grounding metadata received with partial chunks is held pending and yielded with the full-text response."""
1353+
grounding_metadata = types.GroundingMetadata(
1354+
web_search_queries=['stock price of google'],
1355+
)
1356+
1357+
def make_msg(text=None, g_meta=None, tc=False):
1358+
msg = mock.Mock(
1359+
usage_metadata=None,
1360+
tool_call=None,
1361+
session_resumption_update=None,
1362+
go_away=None,
1363+
)
1364+
msg.server_content = mock.Mock(
1365+
interrupted=False,
1366+
input_transcription=None,
1367+
output_transcription=None,
1368+
generation_complete=False,
1369+
turn_complete=tc,
1370+
grounding_metadata=g_meta,
1371+
model_turn=types.Content(
1372+
role='model', parts=[types.Part.from_text(text=text)]
1373+
)
1374+
if text
1375+
else None,
1376+
)
1377+
return msg
1378+
1379+
msg1 = make_msg(text='hello', g_meta=grounding_metadata)
1380+
msg2 = make_msg(text=' world')
1381+
msg3 = make_msg(tc=True)
1382+
1383+
async def gen():
1384+
yield msg1
1385+
yield msg2
1386+
yield msg3
1387+
1388+
mock_gemini_session.receive = mock.Mock(return_value=gen())
1389+
1390+
responses = [resp async for resp in gemini_connection.receive()]
1391+
1392+
# Expected responses:
1393+
# 1. Msg 1 partial (hello) with grounding_metadata
1394+
# 2. Msg 2 partial ( world) without grounding_metadata
1395+
# 3. Full text response (hello world) with PENDING grounding_metadata
1396+
# 4. Turn complete response without grounding_metadata (already cleared)
1397+
assert len(responses) == 4
1398+
1399+
assert responses[0].content.parts[0].text == 'hello'
1400+
assert responses[0].partial is True
1401+
assert responses[0].grounding_metadata == grounding_metadata
1402+
1403+
assert responses[1].content.parts[0].text == ' world'
1404+
assert responses[1].partial is True
1405+
assert responses[1].grounding_metadata is None
1406+
1407+
assert responses[2].content.parts[0].text == 'hello world'
1408+
assert responses[2].partial is False
1409+
assert responses[2].grounding_metadata == grounding_metadata
1410+
1411+
assert responses[3].turn_complete is True
1412+
assert responses[3].grounding_metadata is None

0 commit comments

Comments
 (0)