Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 40 additions & 26 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,32 +304,46 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
# generation_complete, causing transcription to appear after
# tool_call in the session log.
if message.server_content.input_transcription:
if message.server_content.input_transcription.text:
self._input_transcription_text += (
message.server_content.input_transcription.text
)
yield LlmResponse(
input_transcription=types.Transcription(
text=message.server_content.input_transcription.text,
finished=False,
),
partial=True,
model_version=self._model_version,
live_session_id=live_session_id,
)
# finished=True and partial transcription may happen in the same
# message.
if message.server_content.input_transcription.finished:
yield LlmResponse(
input_transcription=types.Transcription(
text=self._input_transcription_text,
finished=True,
),
partial=False,
model_version=self._model_version,
live_session_id=live_session_id,
)
self._input_transcription_text = ''
# Gemini 3.1 Flash Live only sends a single final input
# transcription
if self._is_gemini_3_1_flash_live:
if message.server_content.input_transcription.text:
yield LlmResponse(
input_transcription=types.Transcription(
text=message.server_content.input_transcription.text,
finished=True,
),
partial=False,
model_version=self._model_version,
live_session_id=live_session_id,
)
else:
if message.server_content.input_transcription.text:
self._input_transcription_text += (
message.server_content.input_transcription.text
)
yield LlmResponse(
input_transcription=types.Transcription(
text=message.server_content.input_transcription.text,
finished=False,
),
partial=True,
model_version=self._model_version,
live_session_id=live_session_id,
)
# finished=True and partial transcription may happen in the same
# message.
if message.server_content.input_transcription.finished:
yield LlmResponse(
input_transcription=types.Transcription(
text=self._input_transcription_text,
finished=True,
),
partial=False,
model_version=self._model_version,
live_session_id=live_session_id,
)
self._input_transcription_text = ''
if message.server_content.output_transcription:
if message.server_content.output_transcription.text:
self._output_transcription_text += (
Expand Down
66 changes: 66 additions & 0 deletions tests/unittests/models/test_gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1676,3 +1676,69 @@ async def mock_receive_generator():
assert responses[1].partial is False
assert responses[2].turn_complete is True
assert responses[2].grounding_metadata is None


@pytest.mark.asyncio
async def test_receive_input_transcription_gemini_3_1(
    mock_gemini_session,
):
  """Check input transcription is emitted as final for a Gemini 3.1 live model.

  The connection is created with the 'gemini-3.1-flash-live-preview' model
  version and fed three scripted server messages: one carrying an input
  transcription (whose own `finished` flag is False), one carrying a finished
  output transcription, and one marking the turn complete. The input
  transcription must surface as a single non-partial, finished response; the
  output transcription must still surface as a partial response followed by a
  final one.
  """
  conn = GeminiLlmConnection(
      mock_gemini_session,
      model_version='gemini-3.1-flash-live-preview',
  )

  def optional_transcription(text, finished):
    # A falsy text means the message carries no transcription of this kind.
    if not text:
      return None
    return types.Transcription(text=text, finished=finished)

  def build_message(
      *,
      user_text=None,
      model_text=None,
      model_finished=False,
      turn_done=False,
  ):
    # Server content is a plain Mock; only the fields set here are pinned.
    content = mock.Mock()
    content.interrupted = False
    content.input_transcription = optional_transcription(user_text, False)
    content.output_transcription = optional_transcription(
        model_text, model_finished
    )
    content.generation_complete = False
    content.turn_complete = turn_done
    content.grounding_metadata = None
    content.model_turn = None

    message = mock.create_autospec(types.LiveServerMessage, instance=True)
    message.usage_metadata = None
    message.tool_call = None
    message.session_resumption_update = None
    message.go_away = None
    message.server_content = content
    return message

  scripted_messages = (
      build_message(user_text='Hello'),
      build_message(model_text='Hi there!', model_finished=True),
      build_message(turn_done=True),
  )

  async def replay_messages():
    for scripted in scripted_messages:
      yield scripted

  mock_gemini_session.receive = mock.Mock(return_value=replay_messages())

  responses = []
  async for response in conn.receive():
    responses.append(response)

  assert len(responses) == 4
  user_final, model_partial, model_final, turn_end = responses

  # Input transcription: a single response, already final.
  assert user_final.input_transcription.text == 'Hello'
  assert user_final.input_transcription.finished is True
  assert user_final.partial is False

  # Output transcription: a partial chunk first ...
  assert model_partial.output_transcription.text == 'Hi there!'
  assert model_partial.output_transcription.finished is False
  assert model_partial.partial is True

  # ... then the final response with the same text.
  assert model_final.output_transcription.text == 'Hi there!'
  assert model_final.output_transcription.finished is True
  assert model_final.partial is False

  assert turn_end.turn_complete is True
Loading