Skip to content

Commit c723201

Browse files
seanzhougooglecopybara-github
authored andcommitted
fix: Copy the original function call args before passing it to callback or tools to avoid being modified
PiperOrigin-RevId: 788192459
1 parent f29ab5d commit c723201

2 files changed

Lines changed: 172 additions & 2 deletions

File tree

src/google/adk/flows/llm_flows/functions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ async def handle_function_calls_async(
152152
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
153153
# do not use "args" as the variable name, because it is a reserved keyword
154154
# in python debugger.
155-
function_args = function_call.args or {}
155+
# make a copy to avoid being modified.
156+
function_args = dict(function_call.args) if function_call.args else {}
156157

157158
# Step 1: Check if plugin before_tool_callback overrides the function
158159
# response.
@@ -277,7 +278,8 @@ async def handle_function_calls_live(
277278
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
278279
# do not use "args" as the variable name, because it is a reserved keyword
279280
# in python debugger.
280-
function_args = function_call.args or {}
281+
# make a copy to avoid being modified.
282+
function_args = dict(function_call.args) if function_call.args else {}
281283
function_response = None
282284

283285
# Handle before_tool_callbacks - iterate through the canonical callback

tests/unittests/flows/llm_flows/test_functions_simple.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,171 @@ def test_find_function_call_event_multiple_function_responses():
392392
# Should return the first matching function call event found
393393
result = find_matching_function_call(events)
394394
assert result == call_event1 # First match (func_123)
395+
396+
397+
@pytest.mark.asyncio
398+
async def test_function_call_args_not_modified():
399+
"""Test that function_call.args is not modified when making a copy."""
400+
from google.adk.flows.llm_flows.functions import handle_function_calls_async
401+
from google.adk.flows.llm_flows.functions import handle_function_calls_live
402+
403+
def simple_fn(**kwargs) -> dict:
404+
return {'result': 'test'}
405+
406+
tool = FunctionTool(simple_fn)
407+
model = testing_utils.MockModel.create(responses=[])
408+
agent = Agent(
409+
name='test_agent',
410+
model=model,
411+
tools=[tool],
412+
)
413+
invocation_context = await testing_utils.create_invocation_context(
414+
agent=agent, user_content=''
415+
)
416+
417+
# Create original args that we want to ensure are not modified
418+
original_args = {'param1': 'value1', 'param2': 42}
419+
function_call = types.FunctionCall(name=tool.name, args=original_args)
420+
content = types.Content(parts=[types.Part(function_call=function_call)])
421+
event = Event(
422+
invocation_id=invocation_context.invocation_id,
423+
author=agent.name,
424+
content=content,
425+
)
426+
tools_dict = {tool.name: tool}
427+
428+
# Test handle_function_calls_async
429+
result_async = await handle_function_calls_async(
430+
invocation_context,
431+
event,
432+
tools_dict,
433+
)
434+
435+
# Verify original args are not modified
436+
assert function_call.args == original_args
437+
assert function_call.args is not original_args # Should be a copy
438+
439+
# Test handle_function_calls_live
440+
result_live = await handle_function_calls_live(
441+
invocation_context,
442+
event,
443+
tools_dict,
444+
)
445+
446+
# Verify original args are still not modified
447+
assert function_call.args == original_args
448+
assert function_call.args is not original_args # Should be a copy
449+
450+
# Both should return valid results
451+
assert result_async is not None
452+
assert result_live is not None
453+
454+
455+
@pytest.mark.asyncio
456+
async def test_function_call_args_none_handling():
457+
"""Test that function_call.args=None is handled correctly."""
458+
from google.adk.flows.llm_flows.functions import handle_function_calls_async
459+
from google.adk.flows.llm_flows.functions import handle_function_calls_live
460+
461+
def simple_fn(**kwargs) -> dict:
462+
return {'result': 'test'}
463+
464+
tool = FunctionTool(simple_fn)
465+
model = testing_utils.MockModel.create(responses=[])
466+
agent = Agent(
467+
name='test_agent',
468+
model=model,
469+
tools=[tool],
470+
)
471+
invocation_context = await testing_utils.create_invocation_context(
472+
agent=agent, user_content=''
473+
)
474+
475+
# Create function call with None args
476+
function_call = types.FunctionCall(name=tool.name, args=None)
477+
content = types.Content(parts=[types.Part(function_call=function_call)])
478+
event = Event(
479+
invocation_id=invocation_context.invocation_id,
480+
author=agent.name,
481+
content=content,
482+
)
483+
tools_dict = {tool.name: tool}
484+
485+
# Test handle_function_calls_async
486+
result_async = await handle_function_calls_async(
487+
invocation_context,
488+
event,
489+
tools_dict,
490+
)
491+
492+
# Test handle_function_calls_live
493+
result_live = await handle_function_calls_live(
494+
invocation_context,
495+
event,
496+
tools_dict,
497+
)
498+
499+
# Both should return valid results even with None args
500+
assert result_async is not None
501+
assert result_live is not None
502+
503+
504+
@pytest.mark.asyncio
505+
async def test_function_call_args_copy_behavior():
506+
"""Test that modifying the copied args doesn't affect the original."""
507+
from google.adk.flows.llm_flows.functions import handle_function_calls_async
508+
from google.adk.flows.llm_flows.functions import handle_function_calls_live
509+
510+
def simple_fn(test_param: str, other_param: int) -> dict:
511+
# Modify the args to test that the copy prevents affecting the original
512+
return {
513+
'result': 'test',
514+
'received_args': {'test_param': test_param, 'other_param': other_param},
515+
}
516+
517+
tool = FunctionTool(simple_fn)
518+
model = testing_utils.MockModel.create(responses=[])
519+
agent = Agent(
520+
name='test_agent',
521+
model=model,
522+
tools=[tool],
523+
)
524+
invocation_context = await testing_utils.create_invocation_context(
525+
agent=agent, user_content=''
526+
)
527+
528+
# Create original args
529+
original_args = {'test_param': 'original_value', 'other_param': 123}
530+
function_call = types.FunctionCall(name=tool.name, args=original_args)
531+
content = types.Content(parts=[types.Part(function_call=function_call)])
532+
event = Event(
533+
invocation_id=invocation_context.invocation_id,
534+
author=agent.name,
535+
content=content,
536+
)
537+
tools_dict = {tool.name: tool}
538+
539+
# Test handle_function_calls_async
540+
result_async = await handle_function_calls_async(
541+
invocation_context,
542+
event,
543+
tools_dict,
544+
)
545+
546+
# Verify original args are unchanged
547+
assert function_call.args == original_args
548+
assert function_call.args['test_param'] == 'original_value'
549+
550+
# Verify the tool received the args correctly
551+
assert result_async is not None
552+
response = result_async.content.parts[0].function_response.response
553+
554+
# Check if the response has the expected structure
555+
assert 'received_args' in response
556+
received_args = response['received_args']
557+
assert 'test_param' in received_args
558+
assert received_args['test_param'] == 'original_value'
559+
assert received_args['other_param'] == 123
560+
assert (
561+
function_call.args['test_param'] == 'original_value'
562+
) # Original unchanged

0 commit comments

Comments
 (0)