Skip to content

Commit f8e9195

Browse files
maru0804GWeale
authored andcommitted
fix(planners): allow BuiltInPlanner subclasses to override process_planning_response
Merge google#4141 ## Summary Fixes google#4133 ### Problem When users create a subclass of `BuiltInPlanner` and override `process_planning_response()`, the method was never called because the response processor used `isinstance(planner, BuiltInPlanner)` which returns `True` for all subclasses. ### Solution Changed the check to detect whether `process_planning_response` has been overridden: ```python # Before if not planner or isinstance(planner, BuiltInPlanner): return # After if ( not planner or type(planner).process_planning_response is BuiltInPlanner.process_planning_response ): return ``` This ensures: - `BuiltInPlanner` itself is skipped (returns `None`) - Subclasses **without** override are skipped (avoids side effects) - Subclasses **with** override have their method called ### Testing Added 3 new tests: 1. `test_overridden_subclass_process_planning_response_called` - Regression test for google#4133 2. `test_base_builtin_planner_process_planning_response_not_called` - Verifies base class is skipped 3. `test_non_overridden_subclass_process_planning_response_not_called` - Verifies non-overriding subclasses are also skipped Co-authored-by: George Weale <gweale@google.com> COPYBARA_INTEGRATE_REVIEW=google#4141 from maru0804:fix/4133-planner-process-planning-response 8d57323 PiperOrigin-RevId: 932853783
1 parent 8e2b06d commit f8e9195

2 files changed

Lines changed: 97 additions & 1 deletion

File tree

src/google/adk/flows/llm_flows/_nl_planning.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,11 @@ async def run_async(
8282
return
8383

8484
planner = _get_planner(invocation_context)
85-
if not planner or isinstance(planner, BuiltInPlanner):
85+
if (
86+
not planner
87+
or type(planner).process_planning_response
88+
is BuiltInPlanner.process_planning_response
89+
):
8690
return
8791

8892
# Postprocess the LLM response.

tests/unittests/flows/llm_flows/test_nl_planning.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,17 @@
1414

1515
"""Unit tests for NL planning logic."""
1616

17+
from typing import List
18+
from typing import Optional
1719
from unittest.mock import MagicMock
20+
from unittest.mock import patch
1821

22+
from google.adk.agents.callback_context import CallbackContext
1923
from google.adk.agents.llm_agent import Agent
2024
from google.adk.flows.llm_flows._nl_planning import request_processor
25+
from google.adk.flows.llm_flows._nl_planning import response_processor
2126
from google.adk.models.llm_request import LlmRequest
27+
from google.adk.models.llm_response import LlmResponse
2228
from google.adk.planners.built_in_planner import BuiltInPlanner
2329
from google.adk.planners.plan_re_act_planner import PlanReActPlanner
2430
from google.genai import types
@@ -126,3 +132,89 @@ async def test_remove_thought_from_request_with_thoughts():
126132
for content in llm_request.contents
127133
for part in content.parts or []
128134
)
135+
136+
137+
class OverriddenBuiltInPlanner(BuiltInPlanner):
138+
"""Subclass that overrides process_planning_response."""
139+
140+
def __init__(self, *, thinking_config: types.ThinkingConfig):
141+
super().__init__(thinking_config=thinking_config)
142+
self.process_planning_response_called = False
143+
self.received_parts = None
144+
145+
def process_planning_response(
146+
self,
147+
callback_context: CallbackContext,
148+
response_parts: List[types.Part],
149+
) -> Optional[List[types.Part]]:
150+
self.process_planning_response_called = True
151+
self.received_parts = response_parts
152+
return response_parts
153+
154+
155+
class NonOverriddenBuiltInPlanner(BuiltInPlanner):
156+
"""Subclass that does NOT override process_planning_response."""
157+
158+
pass
159+
160+
161+
@pytest.mark.asyncio
162+
async def test_overridden_subclass_process_planning_response_called():
163+
"""Test that subclasses overriding process_planning_response have it called.
164+
165+
Regression test for issue #4133.
166+
"""
167+
planner = OverriddenBuiltInPlanner(thinking_config=types.ThinkingConfig())
168+
agent = Agent(name='test_agent', planner=planner)
169+
invocation_context = await testing_utils.create_invocation_context(
170+
agent=agent, user_content='test message'
171+
)
172+
173+
response_parts = [
174+
types.Part(text='thinking...', thought=True),
175+
types.Part(text='Here is my response'),
176+
]
177+
llm_response = LlmResponse(
178+
content=types.Content(role='model', parts=response_parts)
179+
)
180+
181+
async for _ in response_processor.run_async(invocation_context, llm_response):
182+
pass
183+
184+
assert planner.process_planning_response_called
185+
assert planner.received_parts == response_parts
186+
187+
188+
@pytest.mark.asyncio
189+
@pytest.mark.parametrize(
190+
'planner_class',
191+
[BuiltInPlanner, NonOverriddenBuiltInPlanner],
192+
ids=['base_class', 'non_overridden_subclass'],
193+
)
194+
async def test_process_planning_response_not_called_without_override(
195+
planner_class,
196+
):
197+
"""Test that process_planning_response is not called for base or non-overridden subclasses."""
198+
planner = planner_class(thinking_config=types.ThinkingConfig())
199+
agent = Agent(name='test_agent', planner=planner)
200+
invocation_context = await testing_utils.create_invocation_context(
201+
agent=agent, user_content='test message'
202+
)
203+
204+
response_parts = [
205+
types.Part(text='thinking...', thought=True),
206+
types.Part(text='Here is my response'),
207+
]
208+
llm_response = LlmResponse(
209+
content=types.Content(role='model', parts=response_parts)
210+
)
211+
212+
with patch.object(
213+
BuiltInPlanner,
214+
'process_planning_response',
215+
) as mock_method:
216+
async for _ in response_processor.run_async(
217+
invocation_context, llm_response
218+
):
219+
pass
220+
mock_method.assert_not_called()

0 commit comments

Comments
 (0)