Skip to content

Commit 9126acb

Browse files
committed
feat(models): Support turn_complete_reason in Live responses to capture safety info
Allows clients to capture the safety filter and turn completion reasons when using the Gemini Live API (v3.1+). Previously in Live v2.5, SafetyFilterMetadata was passed back, but with Live v3.1 it has been moved to turn_complete_reason. This change adds support for turn_complete_reason in LlmResponse and populates it from the server content. Change-Id: Ie6d970e65f3937fb1cfa3bf7f01df79f4f1a71d8
1 parent aa51512 commit 9126acb

3 files changed

Lines changed: 145 additions & 0 deletions

File tree

src/google/adk/models/gemini_llm_connection.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
233233
interrupted=message.server_content.interrupted,
234234
model_version=self._model_version,
235235
live_session_id=live_session_id,
236+
turn_complete_reason=getattr(
237+
message.server_content, 'turn_complete_reason', None
238+
),
236239
)
237240

238241
if content and content.parts:
@@ -241,6 +244,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
241244
interrupted=message.server_content.interrupted,
242245
model_version=self._model_version,
243246
live_session_id=live_session_id,
247+
turn_complete_reason=getattr(
248+
message.server_content, 'turn_complete_reason', None
249+
),
244250
)
245251
# grounding_metadata is yielded again at turn_complete,
246252
# so avoid duplicating it here if turn_complete is true.
@@ -373,6 +379,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
373379
or g_metadata_to_yield,
374380
model_version=self._model_version,
375381
live_session_id=live_session_id,
382+
turn_complete_reason=getattr(
383+
message.server_content, 'turn_complete_reason', None
384+
),
376385
)
377386
break
378387
# in case of empty content or parts, we still surface it

src/google/adk/models/llm_response.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ class LlmResponse(BaseModel):
8181
Only used for streaming mode.
8282
"""
8383

84+
turn_complete_reason: Optional[types.TurnCompleteReason] = None
85+
"""The reason why the turn is complete.
86+
87+
Only used for streaming mode.
88+
"""
89+
8490
finish_reason: Optional[types.FinishReason] = None
8591
"""The finish reason of the response."""
8692

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,3 +1410,133 @@ async def gen():
14101410

14111411
assert responses[3].turn_complete is True
14121412
assert responses[3].grounding_metadata is None
1413+
1414+
1415+
@pytest.mark.asyncio
1416+
async def test_receive_populates_turn_complete_reason(
1417+
gemini_connection, mock_gemini_session
1418+
):
1419+
"""Test that receive populates turn_complete_reason in LlmResponse."""
1420+
mock_server_content = mock.create_autospec(
1421+
types.LiveServerContent, instance=True
1422+
)
1423+
mock_server_content.model_turn = None
1424+
mock_server_content.grounding_metadata = None
1425+
mock_server_content.turn_complete = True
1426+
mock_server_content.interrupted = False
1427+
mock_server_content.input_transcription = None
1428+
mock_server_content.output_transcription = None
1429+
mock_server_content.generation_complete = False
1430+
mock_server_content.turn_complete_reason = (
1431+
types.TurnCompleteReason.RESPONSE_REJECTED
1432+
)
1433+
1434+
mock_message = mock.create_autospec(types.LiveServerMessage, instance=True)
1435+
mock_message.usage_metadata = None
1436+
mock_message.server_content = mock_server_content
1437+
mock_message.tool_call = None
1438+
mock_message.session_resumption_update = None
1439+
mock_message.go_away = None
1440+
1441+
async def mock_receive_generator():
1442+
yield mock_message
1443+
1444+
mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator())
1445+
1446+
responses = [resp async for resp in gemini_connection.receive()]
1447+
1448+
assert len(responses) == 1
1449+
assert responses[0].turn_complete is True
1450+
assert (
1451+
responses[0].turn_complete_reason
1452+
== types.TurnCompleteReason.RESPONSE_REJECTED
1453+
)
1454+
1455+
1456+
@pytest.mark.asyncio
1457+
async def test_receive_populates_turn_complete_reason_standalone_grounding(
1458+
gemini_connection, mock_gemini_session
1459+
):
1460+
"""Test that receive populates turn_complete_reason in LlmResponse for standalone grounding metadata."""
1461+
mock_server_content = mock.create_autospec(
1462+
types.LiveServerContent, instance=True
1463+
)
1464+
mock_server_content.model_turn = None
1465+
mock_server_content.grounding_metadata = mock.create_autospec(
1466+
types.GroundingMetadata, instance=True
1467+
)
1468+
mock_server_content.turn_complete = False
1469+
mock_server_content.interrupted = False
1470+
mock_server_content.input_transcription = None
1471+
mock_server_content.output_transcription = None
1472+
mock_server_content.generation_complete = False
1473+
mock_server_content.turn_complete_reason = (
1474+
types.TurnCompleteReason.RESPONSE_REJECTED
1475+
)
1476+
1477+
mock_message = mock.create_autospec(types.LiveServerMessage, instance=True)
1478+
mock_message.usage_metadata = None
1479+
mock_message.server_content = mock_server_content
1480+
mock_message.tool_call = None
1481+
mock_message.session_resumption_update = None
1482+
mock_message.go_away = None
1483+
1484+
async def mock_receive_generator():
1485+
yield mock_message
1486+
1487+
mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator())
1488+
1489+
responses = [resp async for resp in gemini_connection.receive()]
1490+
1491+
assert len(responses) == 1
1492+
assert responses[0].grounding_metadata is not None
1493+
assert responses[0].turn_complete is None
1494+
assert (
1495+
responses[0].turn_complete_reason
1496+
== types.TurnCompleteReason.RESPONSE_REJECTED
1497+
)
1498+
1499+
1500+
@pytest.mark.asyncio
1501+
async def test_receive_populates_turn_complete_reason_with_content(
1502+
gemini_connection, mock_gemini_session
1503+
):
1504+
"""Test that receive populates turn_complete_reason in LlmResponse when model turn has content parts."""
1505+
mock_content = types.Content(
1506+
role='model',
1507+
parts=[types.Part.from_text(text='hello')],
1508+
)
1509+
mock_server_content = mock.create_autospec(
1510+
types.LiveServerContent, instance=True
1511+
)
1512+
mock_server_content.model_turn = mock_content
1513+
mock_server_content.grounding_metadata = None
1514+
mock_server_content.turn_complete = False
1515+
mock_server_content.interrupted = False
1516+
mock_server_content.input_transcription = None
1517+
mock_server_content.output_transcription = None
1518+
mock_server_content.generation_complete = False
1519+
mock_server_content.turn_complete_reason = (
1520+
types.TurnCompleteReason.RESPONSE_REJECTED
1521+
)
1522+
1523+
mock_message = mock.create_autospec(types.LiveServerMessage, instance=True)
1524+
mock_message.usage_metadata = None
1525+
mock_message.server_content = mock_server_content
1526+
mock_message.tool_call = None
1527+
mock_message.session_resumption_update = None
1528+
mock_message.go_away = None
1529+
1530+
async def mock_receive_generator():
1531+
yield mock_message
1532+
1533+
mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator())
1534+
1535+
responses = [resp async for resp in gemini_connection.receive()]
1536+
1537+
assert len(responses) == 1
1538+
assert responses[0].content == mock_content
1539+
assert (
1540+
responses[0].turn_complete_reason
1541+
== types.TurnCompleteReason.RESPONSE_REJECTED
1542+
)

0 commit comments

Comments
 (0)