Skip to content

Commit 4dde4c1

Browse files
committed
fix: harden SetModelResponseTool fallback to prevent infinite loops
Flash models (gemini-2.5-flash, gemini-3-flash) can ignore set_model_response and loop indefinitely when output_schema is used with tools. This adds a layered defense: 1. Type-aware instruction: primitive schemas (str, int) get a stronger prompt since their trivial tool signature is easily ignored by flash models. 2. Deterministic tool_choice guard: on round N-1 (_MAX_TOOL_ROUNDS-1), restrict the model to only call set_model_response via tool_config. 3. Hard cutoff: on round N, terminate the invocation entirely to prevent runaway API costs. 4. Early return after set_model_response: skip unnecessary transfer_to_agent processing in base_llm_flow.py after structured output is successfully produced. Based on analysis by @surfai, @nino-robotfutures-co, and @surajksharma07 on #5054.
1 parent 1104523 commit 4dde4c1

File tree

4 files changed

+436
-10
lines changed

4 files changed

+436
-10
lines changed

src/google/adk/flows/llm_flows/_output_schema_processor.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,25 @@
1717
from __future__ import annotations
1818

1919
import json
20+
import logging
2021
from typing import AsyncGenerator
2122

23+
from google.genai import types
2224
from typing_extensions import override
2325

2426
from ...agents.invocation_context import InvocationContext
2527
from ...events.event import Event
2628
from ...models.llm_request import LlmRequest
2729
from ...tools.set_model_response_tool import SetModelResponseTool
30+
from ...utils._schema_utils import is_basemodel_schema
2831
from ...utils.output_schema_utils import can_use_output_schema_with_tools
2932
from ._base_llm_processor import BaseLlmRequestProcessor
3033

34+
logger = logging.getLogger('google_adk.' + __name__)
35+
36+
# Max tool rounds before forcing set_model_response (N-1) or terminating (N).
37+
_MAX_TOOL_ROUNDS = 25
38+
3139

3240
class _OutputSchemaRequestProcessor(BaseLlmRequestProcessor):
3341
"""Processor that handles output schema for agents with tools."""
@@ -36,8 +44,6 @@ class _OutputSchemaRequestProcessor(BaseLlmRequestProcessor):
3644
async def run_async(
3745
self, invocation_context: InvocationContext, llm_request: LlmRequest
3846
) -> AsyncGenerator[Event, None]:
39-
from ...agents.llm_agent import LlmAgent
40-
4147
agent = invocation_context.agent
4248

4349
# Check if we need the processor: output_schema + tools + cannot use output
@@ -49,20 +55,56 @@ async def run_async(
4955
):
5056
return
5157

58+
# Count how many tool rounds have occurred in this invocation.
59+
tool_rounds = sum(
60+
1
61+
for e in invocation_context._get_events(
62+
current_invocation=True, current_branch=True
63+
)
64+
if e.get_function_responses()
65+
)
66+
67+
# Terminate the invocation if the model never calls set_model_response.
68+
if tool_rounds >= _MAX_TOOL_ROUNDS:
69+
logger.error(
70+
'Tool execution reached %d rounds without producing structured'
71+
' output via set_model_response. Breaking loop to prevent'
72+
' runaway API costs.',
73+
tool_rounds,
74+
)
75+
invocation_context.end_invocation = True
76+
return
77+
5278
# Add the set_model_response tool to handle structured output
5379
set_response_tool = SetModelResponseTool(agent.output_schema)
5480
llm_request.append_tools([set_response_tool])
5581

56-
# Add instruction about using the set_model_response tool
57-
instruction = (
58-
'IMPORTANT: You have access to other tools, but you must provide '
59-
'your final response using the set_model_response tool with the '
60-
'required structured format. After using any other tools needed '
61-
'to complete the task, always call set_model_response with your '
62-
'final answer in the specified schema format.'
63-
)
82+
# Primitive types (str, int, etc.) produce a trivial tool signature
83+
# that flash models tend to ignore, so use a stronger instruction.
84+
if is_basemodel_schema(agent.output_schema):
85+
instruction = (
86+
'After completing any needed tool calls, provide your final'
87+
' response by calling set_model_response with the required'
88+
' fields.'
89+
)
90+
else:
91+
instruction = (
92+
'IMPORTANT: After using any needed tools, you MUST call'
93+
' set_model_response to provide your final answer.'
94+
' This is required to complete the task.'
95+
)
6496
llm_request.append_instructions([instruction])
6597

98+
# On round N-1, restrict the model to only call set_model_response.
99+
if tool_rounds >= _MAX_TOOL_ROUNDS - 1:
100+
llm_request.config = llm_request.config or types.GenerateContentConfig()
101+
llm_request.config.tool_config = types.ToolConfig(
102+
function_calling_config=types.FunctionCallingConfig(
103+
mode=types.FunctionCallingConfigMode.ANY,
104+
allowed_function_names=['set_model_response'],
105+
)
106+
)
107+
66108
return
67109
yield # Generator requires yield statement in function body.
68110

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,7 @@ async def _postprocess_live(
10451045
)
10461046
)
10471047
yield final_event
1048+
return # Skip further processing after set_model_response.
10481049

10491050
async def _postprocess_run_processors_async(
10501051
self, invocation_context: InvocationContext, llm_response: LlmResponse
@@ -1091,6 +1092,7 @@ async def _postprocess_handle_function_calls_async(
10911092
)
10921093
)
10931094
yield final_event
1095+
return # Skip transfer_to_agent after set_model_response.
10941096
transfer_to_agent = function_response_event.actions.transfer_to_agent
10951097
if transfer_to_agent:
10961098
agent_to_run = self._get_agent_to_run(
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""Integration test for output_schema + tools behavior.
2+
3+
Requires GOOGLE_API_KEY or Vertex AI credentials.
4+
Run with: python -m pytest tests/integration/test_output_schema_with_tools.py -v -s
5+
"""
6+
7+
import os
8+
import time
9+
10+
from google.adk.agents.llm_agent import LlmAgent
11+
from google.adk.runners import Runner
12+
from google.adk.sessions.in_memory_session_service import InMemorySessionService
13+
from google.genai import types
14+
from pydantic import BaseModel
15+
from pydantic import Field
16+
import pytest
17+
18+
19+
class AnalysisResult(BaseModel):
  """Structured output schema the agent under test must produce.

  Used as `output_schema` to exercise the output_schema + tools code path.
  """

  # Free-text summary of the analysis produced by the model.
  summary: str = Field(description='Brief summary of the analysis')
  # Model-reported confidence; schema documents the 0..1 range but does not
  # enforce it — validation bounds are intentionally left to the model.
  confidence: float = Field(description='Confidence score between 0 and 1')
22+
23+
24+
def search_data(query: str) -> str:
  """Search for data based on the query.

  Stub tool: echoes the query back inside a canned result string so the
  agent has something deterministic to call.
  """
  template = 'Found data for: {}. Revenue is $1M, growth is 15%.'
  return template.format(query)
27+
28+
29+
def calculate_metric(metric_name: str, value: float) -> str:
  """Calculate a business metric.

  Stub tool: applies a fixed 10% adjustment and formats the result to two
  decimal places.
  """
  adjusted = value * 1.1
  return '%s: %.2f (adjusted)' % (metric_name, adjusted)
32+
33+
34+
# Skip if no API key is configured.
skip_no_api_key = pytest.mark.skipif(
    not (
        os.environ.get('GOOGLE_API_KEY')
        or os.environ.get('GOOGLE_GENAI_USE_VERTEXAI')
    ),
    reason='No Gemini API key or Vertex AI configured',
)
40+
41+
42+
@skip_no_api_key
@pytest.mark.asyncio
async def test_basemodel_schema_with_tools():
  """Test that BaseModel output_schema + tools produces structured output."""
  agent = LlmAgent(
      name='analyst',
      model='gemini-2.5-flash',
      instruction=(
          'Analyze the query using the available tools, then return'
          ' structured output.'
      ),
      output_schema=AnalysisResult,
      tools=[search_data, calculate_metric],
  )

  session_service = InMemorySessionService()
  runner = Runner(
      agent=agent, app_name='test_app', session_service=session_service
  )
  session = await session_service.create_session(
      app_name='test_app', user_id='test_user'
  )

  user_message = types.Content(
      role='user',
      parts=[types.Part(text='Analyze Q1 revenue performance')],
  )

  start = time.time()
  # Drain the full event stream; duration is the loop-detection signal.
  events = [
      event
      async for event in runner.run_async(
          user_id='test_user',
          session_id=session.id,
          new_message=user_message,
      )
  ]
  elapsed = time.time() - start

  # Should complete within a reasonable time (not infinite loop).
  assert elapsed < 120, f'Took {elapsed:.1f}s — possible infinite loop'

  # Should have at least one event with structured output.
  final_texts = [
      e.content.parts[0].text
      for e in events
      if e.content and e.content.parts and e.content.parts[0].text
  ]
  assert len(final_texts) > 0, 'No text output produced'
  print(f'\nCompleted in {elapsed:.1f}s with {len(events)} events')
  print(f'Final output: {final_texts[-1][:200]}')
92+
93+
94+
@skip_no_api_key
@pytest.mark.asyncio
async def test_str_schema_with_tools():
  """Test that str output_schema + tools produces output (not infinite loop)."""
  agent = LlmAgent(
      name='analyst',
      model='gemini-2.5-flash',
      instruction='Search for the data, then provide a brief text summary.',
      output_schema=str,
      tools=[search_data],
  )

  session_service = InMemorySessionService()
  runner = Runner(
      agent=agent, app_name='test_app', session_service=session_service
  )
  session = await session_service.create_session(
      app_name='test_app', user_id='test_user'
  )

  user_message = types.Content(
      role='user',
      parts=[types.Part(text='What is the Q1 revenue?')],
  )

  start = time.time()
  # Drain the full event stream; duration is the loop-detection signal.
  events = [
      event
      async for event in runner.run_async(
          user_id='test_user',
          session_id=session.id,
          new_message=user_message,
      )
  ]
  elapsed = time.time() - start

  assert elapsed < 120, f'Took {elapsed:.1f}s — possible infinite loop'
  assert len(events) > 0, 'No events produced'
  print(f'\nCompleted in {elapsed:.1f}s with {len(events)} events')

0 commit comments

Comments
 (0)