Skip to content

Commit b0bd310

Browse files
committed
Address Gemini Code Assist review feedback
- Fix critical bug: Remove premature reset of last_grounding_metadata before turn_complete response to prevent data loss - Simplify duplicate reset logic in interrupted handling - Add grounding_metadata propagation to tool_call responses - Add test for grounding_metadata with text content + turn_complete - Add test for grounding_metadata with tool_call responses All 27 tests pass.
1 parent 377a632 commit b0bd310

2 files changed

Lines changed: 110 additions & 5 deletions

File tree

src/google/adk/models/gemini_llm_connection.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
294294
if text:
295295
yield self.__build_full_text_response(text, last_grounding_metadata)
296296
text = ''
297-
last_grounding_metadata = None
298297
yield LlmResponse(
299298
turn_complete=True,
300299
interrupted=message.server_content.interrupted,
@@ -310,23 +309,25 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
310309
if text:
311310
yield self.__build_full_text_response(text, last_grounding_metadata)
312311
text = ''
313-
last_grounding_metadata = None
314312
else:
315313
yield LlmResponse(
316314
interrupted=message.server_content.interrupted,
317315
grounding_metadata=last_grounding_metadata,
318316
)
319-
last_grounding_metadata = None # Reset after yielding
317+
last_grounding_metadata = None # Reset after yielding
320318
if message.tool_call:
321319
if text:
322320
yield self.__build_full_text_response(text, last_grounding_metadata)
323321
text = ''
324-
last_grounding_metadata = None
325322
parts = [
326323
types.Part(function_call=function_call)
327324
for function_call in message.tool_call.function_calls
328325
]
329-
yield LlmResponse(content=types.Content(role='model', parts=parts))
326+
yield LlmResponse(
327+
content=types.Content(role='model', parts=parts),
328+
grounding_metadata=last_grounding_metadata,
329+
)
330+
last_grounding_metadata = None # Reset after yielding
330331
if message.session_resumption_update:
331332
logger.debug('Received session resumption message: %s', message)
332333
yield (

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,3 +887,107 @@ async def mock_receive_generator():
887887
assert turn_complete_response is not None
888888
# The grounding_metadata should be carried over to turn_complete
889889
assert turn_complete_response.grounding_metadata == mock_grounding_metadata
890+
891+
892+
@pytest.mark.asyncio
893+
async def test_receive_grounding_metadata_with_text_and_turn_complete(
894+
gemini_connection, mock_gemini_session
895+
):
896+
"""Test that grounding_metadata is preserved when text content is followed by turn_complete."""
897+
mock_content = types.Content(
898+
role='model', parts=[types.Part.from_text(text='response text')]
899+
)
900+
mock_grounding_metadata = types.GroundingMetadata(
901+
retrieval_queries=['test query'],
902+
)
903+
904+
# Message with both content and grounding, followed by turn_complete
905+
mock_server_content = mock.Mock()
906+
mock_server_content.model_turn = mock_content
907+
mock_server_content.interrupted = False
908+
mock_server_content.input_transcription = None
909+
mock_server_content.output_transcription = None
910+
mock_server_content.turn_complete = True
911+
mock_server_content.generation_complete = False
912+
mock_server_content.grounding_metadata = mock_grounding_metadata
913+
914+
mock_message = mock.Mock()
915+
mock_message.usage_metadata = None
916+
mock_message.server_content = mock_server_content
917+
mock_message.tool_call = None
918+
mock_message.session_resumption_update = None
919+
920+
async def mock_receive_generator():
921+
yield mock_message
922+
923+
receive_mock = mock.Mock(return_value=mock_receive_generator())
924+
mock_gemini_session.receive = receive_mock
925+
926+
responses = [resp async for resp in gemini_connection.receive()]
927+
928+
# Find content response with grounding
929+
content_response = next((r for r in responses if r.content), None)
930+
assert content_response is not None
931+
assert content_response.grounding_metadata == mock_grounding_metadata
932+
933+
# Find turn_complete response - should also have grounding_metadata
934+
turn_complete_response = next((r for r in responses if r.turn_complete), None)
935+
assert turn_complete_response is not None
936+
assert turn_complete_response.grounding_metadata == mock_grounding_metadata
937+
938+
939+
@pytest.mark.asyncio
940+
async def test_receive_grounding_metadata_with_tool_call(
941+
gemini_connection, mock_gemini_session
942+
):
943+
"""Test that grounding_metadata is propagated with tool_call responses."""
944+
mock_grounding_metadata = types.GroundingMetadata(
945+
retrieval_queries=['test query'],
946+
)
947+
948+
# First message with grounding metadata
949+
mock_server_content1 = mock.Mock()
950+
mock_server_content1.model_turn = None
951+
mock_server_content1.interrupted = False
952+
mock_server_content1.input_transcription = None
953+
mock_server_content1.output_transcription = None
954+
mock_server_content1.turn_complete = False
955+
mock_server_content1.generation_complete = False
956+
mock_server_content1.grounding_metadata = mock_grounding_metadata
957+
958+
message1 = mock.Mock()
959+
message1.usage_metadata = None
960+
message1.server_content = mock_server_content1
961+
message1.tool_call = None
962+
message1.session_resumption_update = None
963+
964+
# Second message with tool_call
965+
mock_function_call = types.FunctionCall(
966+
name='test_function', args={'param': 'value'}
967+
)
968+
mock_tool_call = mock.Mock()
969+
mock_tool_call.function_calls = [mock_function_call]
970+
971+
message2 = mock.Mock()
972+
message2.usage_metadata = None
973+
message2.server_content = None
974+
message2.tool_call = mock_tool_call
975+
message2.session_resumption_update = None
976+
977+
async def mock_receive_generator():
978+
yield message1
979+
yield message2
980+
981+
receive_mock = mock.Mock(return_value=mock_receive_generator())
982+
mock_gemini_session.receive = receive_mock
983+
984+
responses = [resp async for resp in gemini_connection.receive()]
985+
986+
# Find tool_call response
987+
tool_call_response = next(
988+
(r for r in responses if r.content and r.content.parts[0].function_call),
989+
None,
990+
)
991+
assert tool_call_response is not None
992+
# The grounding_metadata should be carried over to tool_call
993+
assert tool_call_response.grounding_metadata == mock_grounding_metadata

0 commit comments

Comments
 (0)