Skip to content

Commit c36a708

Browse files
seanzhougooglecopybara-github
authored andcommitted
fix: Support before_tool_callback and after_tool_callback in Live mode
Close #4704 Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 878662637
1 parent 45fb53b commit c36a708

2 files changed

Lines changed: 222 additions & 31 deletions

File tree

src/google/adk/flows/llm_flows/functions.py

Lines changed: 67 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -730,41 +730,77 @@ async def _run_with_trace():
730730
# Make a deep copy to avoid being modified.
731731
function_response = None
732732

733-
# Handle before_tool_callbacks - iterate through the canonical callback
734-
# list
735-
for callback in agent.canonical_before_tool_callbacks:
736-
function_response = callback(
737-
tool=tool, args=function_args, tool_context=tool_context
738-
)
739-
if inspect.isawaitable(function_response):
740-
function_response = await function_response
741-
if function_response:
742-
break
733+
# Step 1: Check if plugin before_tool_callback overrides the function
734+
# response.
735+
function_response = (
736+
await invocation_context.plugin_manager.run_before_tool_callback(
737+
tool=tool, tool_args=function_args, tool_context=tool_context
738+
)
739+
)
743740

741+
# Step 2: If no overrides are provided from the plugins, further run the
742+
# canonical callback.
744743
if function_response is None:
745-
function_response = await _process_function_live_helper(
746-
tool,
747-
tool_context,
748-
function_call,
749-
function_args,
750-
invocation_context,
751-
streaming_lock,
752-
)
744+
for callback in agent.canonical_before_tool_callbacks:
745+
function_response = callback(
746+
tool=tool, args=function_args, tool_context=tool_context
747+
)
748+
if inspect.isawaitable(function_response):
749+
function_response = await function_response
750+
if function_response:
751+
break
753752

754-
# Calls after_tool_callback if it exists.
755-
altered_function_response = None
756-
for callback in agent.canonical_after_tool_callbacks:
757-
altered_function_response = callback(
758-
tool=tool,
759-
args=function_args,
760-
tool_context=tool_context,
761-
tool_response=function_response,
762-
)
763-
if inspect.isawaitable(altered_function_response):
764-
altered_function_response = await altered_function_response
765-
if altered_function_response:
766-
break
753+
# Step 3: Otherwise, proceed calling the tool normally.
754+
if function_response is None:
755+
try:
756+
function_response = await _process_function_live_helper(
757+
tool,
758+
tool_context,
759+
function_call,
760+
function_args,
761+
invocation_context,
762+
streaming_lock,
763+
)
764+
except Exception as tool_error:
765+
error_response = await _run_on_tool_error_callbacks(
766+
tool=tool,
767+
tool_args=function_args,
768+
tool_context=tool_context,
769+
error=tool_error,
770+
)
771+
if error_response is not None:
772+
function_response = error_response
773+
else:
774+
raise tool_error
767775

776+
# Step 4: Check if plugin after_tool_callback overrides the function
777+
# response.
778+
altered_function_response = (
779+
await invocation_context.plugin_manager.run_after_tool_callback(
780+
tool=tool,
781+
tool_args=function_args,
782+
tool_context=tool_context,
783+
result=function_response,
784+
)
785+
)
786+
787+
# Step 5: If no overrides are provided from the plugins, further run the
788+
# canonical after_tool_callbacks.
789+
if altered_function_response is None:
790+
for callback in agent.canonical_after_tool_callbacks:
791+
altered_function_response = callback(
792+
tool=tool,
793+
args=function_args,
794+
tool_context=tool_context,
795+
tool_response=function_response,
796+
)
797+
if inspect.isawaitable(altered_function_response):
798+
altered_function_response = await altered_function_response
799+
if altered_function_response:
800+
break
801+
802+
# Step 6: If alternative response exists from after_tool_callback, use it
803+
# instead of the original function response.
768804
if altered_function_response is not None:
769805
function_response = altered_function_response
770806

tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from google.adk.agents.llm_agent import Agent
2020
from google.adk.events.event import Event
2121
from google.adk.flows.llm_flows.functions import handle_function_calls_async
22+
from google.adk.flows.llm_flows.functions import handle_function_calls_live
2223
from google.adk.plugins.base_plugin import BasePlugin
2324
from google.adk.tools.base_tool import BaseTool
2425
from google.adk.tools.function_tool import FunctionTool
@@ -185,5 +186,159 @@ async def test_async_on_tool_error_fallback_to_runner(
185186
assert e == mock_error
186187

187188

189+
async def invoke_tool_with_plugin_live(
190+
mock_tool, mock_plugin
191+
) -> Optional[Event]:
192+
"""Invokes a tool with a plugin using the live path."""
193+
model = testing_utils.MockModel.create(responses=[])
194+
agent = Agent(
195+
name="agent",
196+
model=model,
197+
tools=[mock_tool],
198+
)
199+
invocation_context = await testing_utils.create_invocation_context(
200+
agent=agent, user_content="", plugins=[mock_plugin]
201+
)
202+
# Build function call event
203+
function_call = types.FunctionCall(name=mock_tool.name, args={})
204+
content = types.Content(parts=[types.Part(function_call=function_call)])
205+
event = Event(
206+
invocation_id=invocation_context.invocation_id,
207+
author=agent.name,
208+
content=content,
209+
)
210+
tools_dict = {mock_tool.name: mock_tool}
211+
return await handle_function_calls_live(
212+
invocation_context,
213+
event,
214+
tools_dict,
215+
)
216+
217+
218+
@pytest.mark.asyncio
219+
async def test_live_before_tool_callback(mock_tool, mock_plugin):
220+
mock_plugin.enable_before_tool_callback = True
221+
222+
result_event = await invoke_tool_with_plugin_live(mock_tool, mock_plugin)
223+
224+
assert result_event is not None
225+
part = result_event.content.parts[0]
226+
assert part.function_response.response == mock_plugin.before_tool_response
227+
228+
229+
@pytest.mark.asyncio
230+
async def test_live_after_tool_callback(mock_tool, mock_plugin):
231+
mock_plugin.enable_after_tool_callback = True
232+
233+
result_event = await invoke_tool_with_plugin_live(mock_tool, mock_plugin)
234+
235+
assert result_event is not None
236+
part = result_event.content.parts[0]
237+
assert part.function_response.response == mock_plugin.after_tool_response
238+
239+
240+
@pytest.mark.asyncio
241+
async def test_live_on_tool_error_use_plugin_response(
242+
mock_error_tool, mock_plugin
243+
):
244+
mock_plugin.enable_on_tool_error_callback = True
245+
246+
result_event = await invoke_tool_with_plugin_live(
247+
mock_error_tool, mock_plugin
248+
)
249+
250+
assert result_event is not None
251+
part = result_event.content.parts[0]
252+
assert part.function_response.response == mock_plugin.on_tool_error_response
253+
254+
255+
@pytest.mark.asyncio
256+
async def test_live_on_tool_error_fallback_to_runner(
257+
mock_error_tool, mock_plugin
258+
):
259+
mock_plugin.enable_on_tool_error_callback = False
260+
261+
try:
262+
await invoke_tool_with_plugin_live(mock_error_tool, mock_plugin)
263+
except Exception as e:
264+
assert e == mock_error
265+
266+
267+
@pytest.mark.asyncio
268+
async def test_live_plugin_before_tool_callback_takes_priority(
269+
mock_tool, mock_plugin
270+
):
271+
"""Plugin before_tool_callback should run before agent canonical callbacks."""
272+
mock_plugin.enable_before_tool_callback = True
273+
274+
def agent_before_cb(tool, args, tool_context):
275+
return {"agent": "should_not_be_called"}
276+
277+
model = testing_utils.MockModel.create(responses=[])
278+
agent = Agent(
279+
name="agent",
280+
model=model,
281+
tools=[mock_tool],
282+
before_tool_callback=agent_before_cb,
283+
)
284+
invocation_context = await testing_utils.create_invocation_context(
285+
agent=agent, user_content="", plugins=[mock_plugin]
286+
)
287+
function_call = types.FunctionCall(name=mock_tool.name, args={})
288+
content = types.Content(parts=[types.Part(function_call=function_call)])
289+
event = Event(
290+
invocation_id=invocation_context.invocation_id,
291+
author=agent.name,
292+
content=content,
293+
)
294+
tools_dict = {mock_tool.name: mock_tool}
295+
result_event = await handle_function_calls_live(
296+
invocation_context, event, tools_dict
297+
)
298+
299+
assert result_event is not None
300+
part = result_event.content.parts[0]
301+
# Plugin response should win, not the agent callback
302+
assert part.function_response.response == mock_plugin.before_tool_response
303+
304+
305+
@pytest.mark.asyncio
306+
async def test_live_plugin_after_tool_callback_takes_priority(
307+
mock_tool, mock_plugin
308+
):
309+
"""Plugin after_tool_callback should run before agent canonical callbacks."""
310+
mock_plugin.enable_after_tool_callback = True
311+
312+
def agent_after_cb(tool, args, tool_context, tool_response):
313+
return {"agent": "should_not_be_called"}
314+
315+
model = testing_utils.MockModel.create(responses=[])
316+
agent = Agent(
317+
name="agent",
318+
model=model,
319+
tools=[mock_tool],
320+
after_tool_callback=agent_after_cb,
321+
)
322+
invocation_context = await testing_utils.create_invocation_context(
323+
agent=agent, user_content="", plugins=[mock_plugin]
324+
)
325+
function_call = types.FunctionCall(name=mock_tool.name, args={})
326+
content = types.Content(parts=[types.Part(function_call=function_call)])
327+
event = Event(
328+
invocation_id=invocation_context.invocation_id,
329+
author=agent.name,
330+
content=content,
331+
)
332+
tools_dict = {mock_tool.name: mock_tool}
333+
result_event = await handle_function_calls_live(
334+
invocation_context, event, tools_dict
335+
)
336+
337+
assert result_event is not None
338+
part = result_event.content.parts[0]
339+
# Plugin response should win, not the agent callback
340+
assert part.function_response.response == mock_plugin.after_tool_response
341+
342+
188343
if __name__ == "__main__":
189344
pytest.main([__file__])

0 commit comments

Comments
 (0)