Skip to content

Commit 3d43c6e

Browse files
committed
Chore: address review feedback for grounding metadata extraction
- Applied formatting using pyink and isort
- Refactored grounding metadata unit tests to use a helper for mock message creation
1 parent 420f991 commit 3d43c6e

3 files changed

Lines changed: 49 additions & 119 deletions

File tree

contributing/samples/gepa/experiment.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
from tau_bench.types import EnvRunResult
4444
from tau_bench.types import RunConfig
4545
import tau_bench_agent as tau_bench_agent_lib
46-
4746
import utils
4847

4948

contributing/samples/gepa/run_experiment.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from absl import flags
2626
import experiment
2727
from google.genai import types
28-
2928
import utils
3029

3130
_OUTPUT_DIR = flags.DEFINE_string(

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 49 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -783,11 +783,36 @@ async def test_send_history_filters_various_audio_mime_types(
783783
mock_gemini_session.send.assert_not_called()
784784

785785

786+
def _create_mock_receive_message(
787+
model_turn=None,
788+
grounding_metadata=None,
789+
interrupted=False,
790+
turn_complete=False,
791+
tool_call=None,
792+
):
793+
"""Helper to create a mock message from the Gemini API."""
794+
mock_server_content = mock.Mock()
795+
mock_server_content.model_turn = model_turn
796+
mock_server_content.interrupted = interrupted
797+
mock_server_content.input_transcription = None
798+
mock_server_content.output_transcription = None
799+
mock_server_content.turn_complete = turn_complete
800+
mock_server_content.generation_complete = False
801+
mock_server_content.grounding_metadata = grounding_metadata
802+
803+
mock_message = mock.Mock()
804+
mock_message.usage_metadata = None
805+
mock_message.server_content = mock_server_content
806+
mock_message.tool_call = tool_call
807+
mock_message.session_resumption_update = None
808+
return mock_message
809+
810+
786811
@pytest.mark.asyncio
787812
async def test_receive_extracts_grounding_metadata(
788813
gemini_connection, mock_gemini_session
789814
):
790-
"""Test that grounding_metadata is extracted from server_content and included in LlmResponse."""
815+
"""Test that grounding_metadata is extracted and included in LlmResponse."""
791816
mock_content = types.Content(
792817
role='model', parts=[types.Part.from_text(text='response text')]
793818
)
@@ -796,20 +821,11 @@ async def test_receive_extracts_grounding_metadata(
796821
web_search_queries=['web search query'],
797822
)
798823

799-
mock_server_content = mock.Mock()
800-
mock_server_content.model_turn = mock_content
801-
mock_server_content.interrupted = False
802-
mock_server_content.input_transcription = None
803-
mock_server_content.output_transcription = None
804-
mock_server_content.turn_complete = True
805-
mock_server_content.generation_complete = False
806-
mock_server_content.grounding_metadata = mock_grounding_metadata
807-
808-
mock_message = mock.Mock()
809-
mock_message.usage_metadata = None
810-
mock_message.server_content = mock_server_content
811-
mock_message.tool_call = None
812-
mock_message.session_resumption_update = None
824+
mock_message = _create_mock_receive_message(
825+
model_turn=mock_content,
826+
grounding_metadata=mock_grounding_metadata,
827+
turn_complete=True,
828+
)
813829

814830
async def mock_receive_generator():
815831
yield mock_message
@@ -842,36 +858,12 @@ async def test_receive_grounding_metadata_at_turn_complete(
842858
)
843859

844860
# First message with grounding but no content
845-
mock_server_content1 = mock.Mock()
846-
mock_server_content1.model_turn = None
847-
mock_server_content1.interrupted = False
848-
mock_server_content1.input_transcription = None
849-
mock_server_content1.output_transcription = None
850-
mock_server_content1.turn_complete = False
851-
mock_server_content1.generation_complete = False
852-
mock_server_content1.grounding_metadata = mock_grounding_metadata
853-
854-
message1 = mock.Mock()
855-
message1.usage_metadata = None
856-
message1.server_content = mock_server_content1
857-
message1.tool_call = None
858-
message1.session_resumption_update = None
861+
message1 = _create_mock_receive_message(
862+
grounding_metadata=mock_grounding_metadata
863+
)
859864

860865
# Second message with turn_complete
861-
mock_server_content2 = mock.Mock()
862-
mock_server_content2.model_turn = None
863-
mock_server_content2.interrupted = False
864-
mock_server_content2.input_transcription = None
865-
mock_server_content2.output_transcription = None
866-
mock_server_content2.turn_complete = True
867-
mock_server_content2.generation_complete = False
868-
mock_server_content2.grounding_metadata = None
869-
870-
message2 = mock.Mock()
871-
message2.usage_metadata = None
872-
message2.server_content = mock_server_content2
873-
message2.tool_call = None
874-
message2.session_resumption_update = None
866+
message2 = _create_mock_receive_message(turn_complete=True)
875867

876868
async def mock_receive_generator():
877869
yield message1
@@ -902,20 +894,11 @@ async def test_receive_grounding_metadata_with_text_and_turn_complete(
902894
)
903895

904896
# Single message carrying content, grounding metadata, and turn_complete=True
905-
mock_server_content = mock.Mock()
906-
mock_server_content.model_turn = mock_content
907-
mock_server_content.interrupted = False
908-
mock_server_content.input_transcription = None
909-
mock_server_content.output_transcription = None
910-
mock_server_content.turn_complete = True
911-
mock_server_content.generation_complete = False
912-
mock_server_content.grounding_metadata = mock_grounding_metadata
913-
914-
mock_message = mock.Mock()
915-
mock_message.usage_metadata = None
916-
mock_message.server_content = mock_server_content
917-
mock_message.tool_call = None
918-
mock_message.session_resumption_update = None
897+
mock_message = _create_mock_receive_message(
898+
model_turn=mock_content,
899+
grounding_metadata=mock_grounding_metadata,
900+
turn_complete=True,
901+
)
919902

920903
async def mock_receive_generator():
921904
yield mock_message
@@ -946,20 +929,9 @@ async def test_receive_grounding_metadata_with_tool_call(
946929
)
947930

948931
# First message with grounding metadata
949-
mock_server_content1 = mock.Mock()
950-
mock_server_content1.model_turn = None
951-
mock_server_content1.interrupted = False
952-
mock_server_content1.input_transcription = None
953-
mock_server_content1.output_transcription = None
954-
mock_server_content1.turn_complete = False
955-
mock_server_content1.generation_complete = False
956-
mock_server_content1.grounding_metadata = mock_grounding_metadata
957-
958-
message1 = mock.Mock()
959-
message1.usage_metadata = None
960-
message1.server_content = mock_server_content1
961-
message1.tool_call = None
962-
message1.session_resumption_update = None
932+
message1 = _create_mock_receive_message(
933+
grounding_metadata=mock_grounding_metadata
934+
)
963935

964936
# Second message with tool_call
965937
mock_function_call = types.FunctionCall(
@@ -968,11 +940,8 @@ async def test_receive_grounding_metadata_with_tool_call(
968940
mock_tool_call = mock.Mock()
969941
mock_tool_call.function_calls = [mock_function_call]
970942

971-
message2 = mock.Mock()
972-
message2.usage_metadata = None
943+
message2 = _create_mock_receive_message(tool_call=mock_tool_call)
973944
message2.server_content = None
974-
message2.tool_call = mock_tool_call
975-
message2.session_resumption_update = None
976945

977946
async def mock_receive_generator():
978947
yield message1
@@ -1006,55 +975,18 @@ async def test_receive_interrupted_with_pending_text_preserves_flag(
1006975
mock_content1 = types.Content(
1007976
role='model', parts=[types.Part.from_text(text='partial')]
1008977
)
1009-
mock_server_content1 = mock.Mock()
1010-
mock_server_content1.model_turn = mock_content1
1011-
mock_server_content1.interrupted = False
1012-
mock_server_content1.input_transcription = None
1013-
mock_server_content1.output_transcription = None
1014-
mock_server_content1.turn_complete = False
1015-
mock_server_content1.generation_complete = False
1016-
mock_server_content1.grounding_metadata = mock_grounding_metadata
1017-
1018-
message1 = mock.Mock()
1019-
message1.usage_metadata = None
1020-
message1.server_content = mock_server_content1
1021-
message1.tool_call = None
1022-
message1.session_resumption_update = None
978+
message1 = _create_mock_receive_message(
979+
model_turn=mock_content1, grounding_metadata=mock_grounding_metadata
980+
)
1023981

1024982
# Second message with more text
1025983
mock_content2 = types.Content(
1026984
role='model', parts=[types.Part.from_text(text=' text')]
1027985
)
1028-
mock_server_content2 = mock.Mock()
1029-
mock_server_content2.model_turn = mock_content2
1030-
mock_server_content2.interrupted = False
1031-
mock_server_content2.input_transcription = None
1032-
mock_server_content2.output_transcription = None
1033-
mock_server_content2.turn_complete = False
1034-
mock_server_content2.generation_complete = False
1035-
mock_server_content2.grounding_metadata = None
1036-
1037-
message2 = mock.Mock()
1038-
message2.usage_metadata = None
1039-
message2.server_content = mock_server_content2
1040-
message2.tool_call = None
1041-
message2.session_resumption_update = None
986+
message2 = _create_mock_receive_message(model_turn=mock_content2)
1042987

1043988
# Third message with interrupted signal
1044-
mock_server_content3 = mock.Mock()
1045-
mock_server_content3.model_turn = None
1046-
mock_server_content3.interrupted = True
1047-
mock_server_content3.input_transcription = None
1048-
mock_server_content3.output_transcription = None
1049-
mock_server_content3.turn_complete = False
1050-
mock_server_content3.generation_complete = False
1051-
mock_server_content3.grounding_metadata = None
1052-
1053-
message3 = mock.Mock()
1054-
message3.usage_metadata = None
1055-
message3.server_content = mock_server_content3
1056-
message3.tool_call = None
1057-
message3.session_resumption_update = None
989+
message3 = _create_mock_receive_message(interrupted=True)
1058990

1059991
async def mock_receive_generator():
1060992
yield message1

0 commit comments

Comments
 (0)