diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index ac797d93e7..ff2f81e657 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -304,32 +304,46 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: # generation_complete, causing transcription to appear after # tool_call in the session log. if message.server_content.input_transcription: - if message.server_content.input_transcription.text: - self._input_transcription_text += ( - message.server_content.input_transcription.text - ) - yield LlmResponse( - input_transcription=types.Transcription( - text=message.server_content.input_transcription.text, - finished=False, - ), - partial=True, - model_version=self._model_version, - live_session_id=live_session_id, - ) - # finished=True and partial transcription may happen in the same - # message. - if message.server_content.input_transcription.finished: - yield LlmResponse( - input_transcription=types.Transcription( - text=self._input_transcription_text, - finished=True, - ), - partial=False, - model_version=self._model_version, - live_session_id=live_session_id, - ) - self._input_transcription_text = '' + # Gemini 3.1 Flash Live only sends a single final input + # transcription + if self._is_gemini_3_1_flash_live: + if message.server_content.input_transcription.text: + yield LlmResponse( + input_transcription=types.Transcription( + text=message.server_content.input_transcription.text, + finished=True, + ), + partial=False, + model_version=self._model_version, + live_session_id=live_session_id, + ) + else: + if message.server_content.input_transcription.text: + self._input_transcription_text += ( + message.server_content.input_transcription.text + ) + yield LlmResponse( + input_transcription=types.Transcription( + text=message.server_content.input_transcription.text, + finished=False, + ), + partial=True, + model_version=self._model_version, + live_session_id=live_session_id, + ) + # finished=True and partial transcription may happen in the same + # message. + if message.server_content.input_transcription.finished: + yield LlmResponse( + input_transcription=types.Transcription( + text=self._input_transcription_text, + finished=True, + ), + partial=False, + model_version=self._model_version, + live_session_id=live_session_id, + ) + self._input_transcription_text = '' if message.server_content.output_transcription: if message.server_content.output_transcription.text: self._output_transcription_text += ( diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 47154306a2..62548dac30 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -1676,3 +1676,69 @@ async def mock_receive_generator(): assert responses[1].partial is False assert responses[2].turn_complete is True assert responses[2].grounding_metadata is None + + +@pytest.mark.asyncio +async def test_receive_input_transcription_gemini_3_1( + mock_gemini_session, +): + """Verify input_transcription yields finished=True immediately for Gemini 3.1.""" + conn = GeminiLlmConnection( + mock_gemini_session, + model_version='gemini-3.1-flash-live-preview', + ) + + def make_msg( + input_text=None, output_text=None, output_finished=False, tc=False + ): + msg = mock.create_autospec(types.LiveServerMessage, instance=True) + msg.usage_metadata = None + msg.tool_call = None + msg.session_resumption_update = None + msg.go_away = None + msg.server_content = mock.Mock() + msg.server_content.interrupted = False + msg.server_content.input_transcription = ( + types.Transcription(text=input_text, finished=False) + if input_text + else None + ) + msg.server_content.output_transcription = ( + types.Transcription(text=output_text, finished=output_finished) + if output_text + else None + ) + msg.server_content.generation_complete = False + msg.server_content.turn_complete = tc + msg.server_content.grounding_metadata = None + msg.server_content.model_turn = None + return msg + + msg1 = make_msg(input_text='Hello') + msg2 = make_msg(output_text='Hi there!', output_finished=True) + msg3 = make_msg(tc=True) + + async def mock_receive_generator(): + yield msg1 + yield msg2 + yield msg3 + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + + responses = [resp async for resp in conn.receive()] + + assert len(responses) == 4 + + assert responses[0].input_transcription.text == 'Hello' + assert responses[0].input_transcription.finished is True + assert responses[0].partial is False + + assert responses[1].output_transcription.text == 'Hi there!' + assert responses[1].output_transcription.finished is False + assert responses[1].partial is True + + assert responses[2].output_transcription.text == 'Hi there!' + assert responses[2].output_transcription.finished is True + assert responses[2].partial is False + + assert responses[3].turn_complete is True