Skip to content

Commit 459782f

Browse files
wukathsasha-gitg
andauthored
Cherry pick fixes to v1 (#5934)
Co-authored-by: asobran <asobran@google.com>
1 parent 7c0e186 commit 459782f

6 files changed

Lines changed: 591 additions & 12 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from ...tools.base_toolset import BaseToolset
4949
from ...tools.tool_context import ToolContext
5050
from ...utils.context_utils import Aclosing
51+
from ...utils import model_name_utils
5152
from .audio_cache_manager import AudioCacheManager
5253
from .functions import build_auth_request_event
5354

@@ -516,6 +517,20 @@ async def run_live(
516517
)
517518
llm_request.live_connect_config.session_resumption.transparent = True
518519

520+
if (
521+
isinstance(llm, Gemini)
522+
and llm._api_backend == GoogleLLMVariant.GEMINI_API
523+
and model_name_utils.is_gemini_3_1_flash_live(llm_request.model)
524+
and llm_request.contents
525+
and not invocation_context.live_session_resumption_handle
526+
):
527+
if llm_request.live_connect_config is None:
528+
llm_request.live_connect_config = types.LiveConnectConfig()
529+
if llm_request.live_connect_config.history_config is None:
530+
llm_request.live_connect_config.history_config = types.HistoryConfig(
531+
initial_history_in_client_content=True
532+
)
533+
519534
logger.info(
520535
'Establishing live connection for agent: %s',
521536
invocation_context.agent.name,

src/google/adk/flows/llm_flows/basic.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ...agents.invocation_context import InvocationContext
2626
from ...events.event import Event
2727
from ...models.llm_request import LlmRequest
28+
from ...utils import model_name_utils
2829
from ...utils.output_schema_utils import can_use_output_schema_with_tools
2930
from ._base_llm_processor import BaseLlmRequestProcessor
3031

@@ -78,11 +79,13 @@ def _build_basic_request(
7879
llm_request.live_connect_config.realtime_input_config = (
7980
invocation_context.run_config.realtime_input_config
8081
)
82+
active_model_name = getattr(getattr(agent, 'canonical_live_model', None), 'model', None) or llm_request.model
83+
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(active_model_name)
8184
llm_request.live_connect_config.enable_affective_dialog = (
82-
invocation_context.run_config.enable_affective_dialog
85+
None if is_gemini_31 else invocation_context.run_config.enable_affective_dialog
8386
)
8487
llm_request.live_connect_config.proactivity = (
85-
invocation_context.run_config.proactivity
88+
None if is_gemini_31 else invocation_context.run_config.proactivity
8689
)
8790
llm_request.live_connect_config.session_resumption = (
8891
invocation_context.run_config.session_resumption

src/google/adk/models/gemini_llm_connection.py

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,24 @@ async def send_history(self, history: list[types.Content]):
8080
]
8181

8282
if contents:
83+
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(
84+
self._model_version
85+
)
86+
# Gemini Enterprise Agent Platform does not support history_config in the SDK.
87+
# To initialize a live session with prior history without hitting a 1007
88+
# protocol error (invalid role mid-session), we consolidate previous multi-turn
89+
# interactions into a unified contextual preamble on a single user role turn.
90+
if is_gemini_31 and self._api_backend != GoogleLLMVariant.GEMINI_API:
91+
collapsed_text = "Previous conversation history:\n"
92+
for c in contents:
93+
text_parts = "".join(p.text for p in c.parts if p.text)
94+
collapsed_text += f'[{c.role}]: {text_parts}\n'
95+
contents = [types.Content(role='user', parts=[types.Part.from_text(text=collapsed_text)])]
96+
8397
logger.debug('Sending history to live connection: %s', contents)
8498
await self._gemini_session.send_client_content(
8599
turns=contents,
86-
turn_complete=contents[-1].role == 'user',
100+
turn_complete=True if is_gemini_31 else (contents[-1].role == 'user'),
87101
)
88102
else:
89103
logger.info('no content is sent')
@@ -159,14 +173,21 @@ async def send_realtime(self, input: RealtimeInput):
159173
else:
160174
raise ValueError('Unsupported input type: %s' % type(input))
161175

162-
def __build_full_text_response(self, text: str):
176+
def __build_full_text_response(
177+
self,
178+
text: str,
179+
is_thought: bool = False,
180+
grounding_metadata: types.GroundingMetadata | None = None,
181+
):
163182
"""Builds a full text response.
164183
165184
The text should not be partial and the returned LlmResponse is not
166185
partial.
167186
168187
Args:
169188
text: The text to be included in the response.
189+
is_thought: Whether the text is a thought.
190+
grounding_metadata: The grounding metadata to include.
170191
171192
Returns:
172193
An LlmResponse containing the full text.
@@ -176,6 +197,8 @@ def __build_full_text_response(self, text: str):
176197
role='model',
177198
parts=[types.Part.from_text(text=text)],
178199
),
200+
grounding_metadata=grounding_metadata,
201+
partial=False,
179202
live_session_id=self._gemini_session.session_id,
180203
)
181204

@@ -188,6 +211,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
188211

189212
text = ''
190213
tool_call_parts = []
214+
pending_grounding_metadata = None
191215
async with Aclosing(self._gemini_session.receive()) as agen:
192216
# TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
193217
# partial content and emit responses as needed.
@@ -203,6 +227,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
203227
)
204228
if message.server_content:
205229
content = message.server_content.model_turn
230+
if message.server_content.grounding_metadata:
231+
pending_grounding_metadata = (
232+
message.server_content.grounding_metadata
233+
)
206234

207235
# Standalone grounding_metadata event (when content is empty)
208236
if (
@@ -215,6 +243,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
215243
interrupted=message.server_content.interrupted,
216244
model_version=self._model_version,
217245
live_session_id=live_session_id,
246+
turn_complete_reason=getattr(
247+
message.server_content, 'turn_complete_reason', None
248+
),
218249
)
219250

220251
if content and content.parts:
@@ -223,19 +254,31 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
223254
interrupted=message.server_content.interrupted,
224255
model_version=self._model_version,
225256
live_session_id=live_session_id,
257+
turn_complete_reason=getattr(
258+
message.server_content, 'turn_complete_reason', None
259+
),
226260
)
227261
# grounding_metadata is yielded again at turn_complete,
228262
# so avoid duplicating it here if turn_complete is true.
229263
if not message.server_content.turn_complete:
230264
llm_response.grounding_metadata = (
231265
message.server_content.grounding_metadata
232266
)
233-
if content.parts[0].text:
234-
text += content.parts[0].text
235-
llm_response.partial = True
267+
has_inline_data = any(p.inline_data for p in content.parts)
268+
for part in content.parts:
269+
if part.text:
270+
current_is_thought = getattr(part, 'thought', False)
271+
if text and current_is_thought != is_thought:
272+
yield self.__build_full_text_response(text, is_thought)
273+
text = ''
274+
is_thought = False
275+
276+
text += part.text
277+
is_thought = current_is_thought
278+
llm_response.partial = True
236279
# don't yield the merged text event when receiving audio data
237-
elif text and not content.parts[0].inline_data:
238-
yield self.__build_full_text_response(text)
280+
if text and not any(p.text for p in content.parts) and not has_inline_data:
281+
yield self.__build_full_text_response(text, is_thought)
239282
text = ''
240283
yield llm_response
241284
# Note: in some cases, tool_call may arrive before
@@ -324,9 +367,14 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
324367
)
325368
self._output_transcription_text = ''
326369
if message.server_content.turn_complete:
370+
g_metadata_to_yield = pending_grounding_metadata
327371
if text:
328-
yield self.__build_full_text_response(text)
372+
yield self.__build_full_text_response(
373+
text, is_thought, g_metadata_to_yield
374+
)
329375
text = ''
376+
is_thought = False
377+
g_metadata_to_yield = None
330378
if tool_call_parts:
331379
logger.debug('Returning aggregated tool_call_parts')
332380
yield LlmResponse(
@@ -338,9 +386,13 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
338386
yield LlmResponse(
339387
turn_complete=True,
340388
interrupted=message.server_content.interrupted,
341-
grounding_metadata=message.server_content.grounding_metadata,
389+
grounding_metadata=message.server_content.grounding_metadata
390+
or g_metadata_to_yield,
342391
model_version=self._model_version,
343392
live_session_id=live_session_id,
393+
turn_complete_reason=getattr(
394+
message.server_content, 'turn_complete_reason', None
395+
),
344396
)
345397
break
346398
# in case of empty content or parts, we still surface it

src/google/adk/models/llm_response.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ class LlmResponse(BaseModel):
8181
Only used for streaming mode.
8282
"""
8383

84+
turn_complete_reason: Optional[types.TurnCompleteReason] = None
85+
"""The reason why the turn is complete.
86+
87+
Only used for streaming mode.
88+
"""
89+
8490
finish_reason: Optional[types.FinishReason] = None
8591
"""The finish reason of the response."""
8692

0 commit comments

Comments
 (0)