diff --git a/src/google/adk/auth/auth_preprocessor.py b/src/google/adk/auth/auth_preprocessor.py index 76dd2ddab4..7d94828f11 100644 --- a/src/google/adk/auth/auth_preprocessor.py +++ b/src/google/adk/auth/auth_preprocessor.py @@ -37,6 +37,44 @@ TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_' +def _find_function_call( + events: list[Event], function_call_id: str +) -> Any | None: + for event in events: + for function_call in event.get_function_calls(): + if function_call.id == function_call_id: + return function_call + return None + + +def _has_requested_auth_config( + events: list[Event], function_call_id: str +) -> bool: + return any( + function_call_id in event.actions.requested_auth_configs + for event in events + ) + + +def _is_valid_auth_resume_target( + events: list[Event], args: AuthToolArguments +) -> bool: + if args.function_call_id.startswith(TOOLSET_AUTH_CREDENTIAL_ID_PREFIX): + return False + if not _has_requested_auth_config(events, args.function_call_id): + return False + if not args.function_call_digest: + return False + + function_call = _find_function_call(events, args.function_call_id) + if not function_call: + return False + + return ( + functions.function_call_digest(function_call) == args.function_call_digest + ) + + async def _store_auth_and_collect_resume_targets( events: list[Event], auth_fc_ids: set[str], @@ -64,21 +102,23 @@ async def _store_auth_and_collect_resume_targets( """ # Step 1: Scan events for matching adk_request_credential function calls # to extract AuthToolArguments (contains credential_key). - requested_auth_config_by_id: dict[str, AuthConfig] = {} + requested_auth_args_by_id: dict[str, AuthToolArguments] = {} for event in events: event_function_calls = event.get_function_calls() if not event_function_calls: continue - try: - for function_call in event_function_calls: - if ( - function_call.id in auth_fc_ids - and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME - ): - args = AuthToolArguments.model_validate(function_call.args) - requested_auth_config_by_id[function_call.id] = args.auth_config - except TypeError: - continue + for function_call in event_function_calls: + if ( + function_call.id not in auth_fc_ids + or function_call.name != REQUEST_EUC_FUNCTION_CALL_NAME + ): + continue + try: + requested_auth_args_by_id[function_call.id] = ( + AuthToolArguments.model_validate(function_call.args) + ) + except TypeError: + continue # Step 2: Store credentials. Merge credential_key from the original # request into the client's auth response before storing. @@ -86,12 +126,13 @@ async def _store_auth_and_collect_resume_targets( if fc_id not in auth_responses: continue auth_config = AuthConfig.model_validate(auth_responses[fc_id]) - requested_auth_config = requested_auth_config_by_id.get(fc_id) - if ( - requested_auth_config - and requested_auth_config.credential_key is not None + requested_auth_args = requested_auth_args_by_id.get(fc_id) + if requested_auth_args and ( + requested_auth_args.auth_config.credential_key is not None ): - auth_config.credential_key = requested_auth_config.credential_key + auth_config.credential_key = ( + requested_auth_args.auth_config.credential_key + ) await AuthHandler(auth_config=auth_config).parse_and_store_auth_response( state=state ) @@ -99,27 +140,9 @@ async def _store_auth_and_collect_resume_targets( # Step 3: Collect original function call IDs to resume, skipping # toolset auth entries which don't map to a resumable function call. tools_to_resume: set[str] = set() - for fc_id in auth_fc_ids: - requested_auth_config = requested_auth_config_by_id.get(fc_id) - if not requested_auth_config: - continue - # Re-parse to get function_call_id (AuthConfig doesn't carry it; - # AuthToolArguments does). - for event in events: - event_function_calls = event.get_function_calls() - if not event_function_calls: - continue - for function_call in event_function_calls: - if ( - function_call.id == fc_id - and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME - ): - args = AuthToolArguments.model_validate(function_call.args) - if args.function_call_id.startswith( - TOOLSET_AUTH_CREDENTIAL_ID_PREFIX - ): - continue - tools_to_resume.add(args.function_call_id) + for args in requested_auth_args_by_id.values(): + if _is_valid_auth_resume_target(events, args): + tools_to_resume.add(args.function_call_id) return tools_to_resume @@ -141,10 +164,12 @@ async def run_async( # Find the last user-authored event with function responses to # identify adk_request_credential responses. last_event_with_content = None + last_event_with_content_index = -1 for i in range(len(events) - 1, -1, -1): event = events[i] if event.content is not None: last_event_with_content = event + last_event_with_content_index = i break if not last_event_with_content or last_event_with_content.author != 'user': @@ -170,8 +195,12 @@ async def run_async( return # Store credentials and collect tools to resume. + prior_events = events[:last_event_with_content_index] tools_to_resume = await _store_auth_and_collect_resume_targets( - events, auth_fc_ids, auth_responses, invocation_context.session.state + prior_events, + auth_fc_ids, + auth_responses, + invocation_context.session.state, ) if not tools_to_resume: @@ -179,7 +208,7 @@ async def run_async( # Find the original function call event and re-execute the tools # that needed auth. - for i in range(len(events) - 2, -1, -1): + for i in range(last_event_with_content_index - 1, -1, -1): event = events[i] function_calls = event.get_function_calls() if not function_calls: diff --git a/src/google/adk/auth/auth_tool.py b/src/google/adk/auth/auth_tool.py index 820540ef12..7171872040 100644 --- a/src/google/adk/auth/auth_tool.py +++ b/src/google/adk/auth/auth_tool.py @@ -146,3 +146,4 @@ class AuthToolArguments(BaseModelWithConfig): function_call_id: str auth_config: AuthConfig + function_call_digest: Optional[str] = None diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 31f998a588..ce69c1f1d2 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -1118,7 +1118,7 @@ async def _postprocess_handle_function_calls_async( invocation_context, function_call_event, llm_request.tools_dict ): auth_event = functions.generate_auth_event( - invocation_context, function_response_event + invocation_context, function_call_event, function_response_event ) if auth_event: yield auth_event diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index ea3b3e76c3..4f2c644d53 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -22,7 +22,9 @@ from concurrent.futures import ThreadPoolExecutor import contextvars import copy +import hashlib import inspect +import json import logging import threading from typing import Any @@ -181,6 +183,20 @@ def generate_client_function_call_id() -> str: return f'{AF_FUNCTION_CALL_ID_PREFIX}{platform_uuid.new_uuid()}' +def function_call_digest(function_call: types.FunctionCall) -> str: + """Returns a stable digest for a function call.""" + dumped = function_call.model_dump( + exclude_none=True, by_alias=True, mode='json' + ) + canonical_json = json.dumps( + dumped, + sort_keys=True, + ensure_ascii=False, + separators=(',', ':'), + ) + return hashlib.sha256(canonical_json.encode('utf-8')).hexdigest() + + def populate_client_function_call_id(model_response_event: Event) -> None: if not model_response_event.get_function_calls(): return @@ -235,6 +251,7 @@ def build_auth_request_event( *, author: Optional[str] = None, role: Optional[str] = None, + function_call_digest_by_id: Optional[dict[str, str]] = None, ) -> Event: """Builds an auth request event with function calls for each auth request. @@ -246,6 +263,8 @@ def build_auth_request_event( auth_requests: Dict mapping function_call_id to AuthConfig. author: The event author. Defaults to agent name. role: The content role. Defaults to None. + function_call_digest_by_id: Optional mapping of function call IDs to stable + digests for tool-level auth requests. Returns: Event with auth request function calls. @@ -260,6 +279,9 @@ def build_auth_request_event( args=AuthToolArguments( function_call_id=function_call_id, auth_config=auth_config, + function_call_digest=(function_call_digest_by_id or {}).get( + function_call_id + ), ).model_dump(exclude_none=True, by_alias=True), ) long_running_tool_ids.add(request_euc_function_call.id) @@ -276,6 +298,7 @@ def build_auth_request_event( def generate_auth_event( invocation_context: InvocationContext, + function_call_event: Event, function_response_event: Event, ) -> Optional[Event]: """Generates an auth request event from a function response event. @@ -285,6 +308,7 @@ def generate_auth_event( Args: invocation_context: The invocation context. + function_call_event: The function call event that produced the response. function_response_event: The function response event with auth requests. Returns: @@ -293,10 +317,18 @@ def generate_auth_event( if not function_response_event.actions.requested_auth_configs: return None + function_call_digest_by_id = { + function_call.id: function_call_digest(function_call) + for function_call in function_call_event.get_function_calls() + if function_call.id + in function_response_event.actions.requested_auth_configs + } + return build_auth_request_event( invocation_context, function_response_event.actions.requested_auth_configs, role=function_response_event.content.role, + function_call_digest_by_id=function_call_digest_by_id, ) diff --git a/src/google/adk/flows/llm_flows/request_confirmation.py b/src/google/adk/flows/llm_flows/request_confirmation.py index d066db791d..e2c296a3ad 100644 --- a/src/google/adk/flows/llm_flows/request_confirmation.py +++ b/src/google/adk/flows/llm_flows/request_confirmation.py @@ -50,6 +50,26 @@ def _parse_tool_confirmation(response: dict[str, Any]) -> ToolConfirmation: return ToolConfirmation.model_validate(response) +def _has_matching_original_function_call( + events: list[Event], original_function_call: types.FunctionCall +) -> bool: + expected_digest = functions.function_call_digest(original_function_call) + for event in events: + for function_call in event.get_function_calls(): + if functions.function_call_digest(function_call) == expected_digest: + return True + return False + + +def _has_requested_tool_confirmation( + events: list[Event], function_call_id: str +) -> bool: + return any( + function_call_id in event.actions.requested_tool_confirmations + for event in events + ) + + def _resolve_confirmation_targets( events: list[Event], confirmation_fc_ids: set[str], @@ -82,6 +102,13 @@ def _resolve_confirmation_targets( for function_call in event_function_calls: if function_call.id not in confirmation_fc_ids: continue + if function_call.name != REQUEST_CONFIRMATION_FUNCTION_CALL_NAME: + continue + if ( + not event.long_running_tool_ids + or function_call.id not in event.long_running_tool_ids + ): + continue args = function_call.args if 'originalFunctionCall' not in args: @@ -89,6 +116,14 @@ def _resolve_confirmation_targets( original_function_call = types.FunctionCall( **args['originalFunctionCall'] ) + if not _has_matching_original_function_call( + events, original_function_call + ): + continue + if not _has_requested_tool_confirmation( + events, original_function_call.id + ): + continue tool_confirmation_dict[original_function_call.id] = ( confirmations_by_fc_id[function_call.id] ) @@ -139,9 +174,10 @@ async def run_async( # Step 2: Resolve confirmation targets using extracted helper. confirmation_fc_ids = set(confirmations_by_fc_id.keys()) + prior_events = events[:confirmation_event_index] tools_to_resume_with_confirmation, tools_to_resume_with_args = ( _resolve_confirmation_targets( - events, confirmation_fc_ids, confirmations_by_fc_id + prior_events, confirmation_fc_ids, confirmations_by_fc_id ) ) diff --git a/tests/unittests/auth/test_auth_preprocessor.py b/tests/unittests/auth/test_auth_preprocessor.py index fb45cc34ac..74ffb8b68b 100644 --- a/tests/unittests/auth/test_auth_preprocessor.py +++ b/tests/unittests/auth/test_auth_preprocessor.py @@ -23,11 +23,15 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.auth.auth_handler import AuthHandler from google.adk.auth.auth_preprocessor import _AuthLlmRequestProcessor +from google.adk.auth.auth_schemes import CustomAuthScheme from google.adk.auth.auth_tool import AuthConfig from google.adk.auth.auth_tool import AuthToolArguments from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.flows.llm_flows import functions from google.adk.flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME from google.adk.models.llm_request import LlmRequest +from google.genai import types import pytest @@ -320,8 +324,13 @@ async def test_processes_auth_response_successfully( @patch('google.adk.auth.auth_preprocessor.AuthHandler') @patch('google.adk.auth.auth_tool.AuthConfig.model_validate') @patch('google.adk.flows.llm_flows.functions.handle_function_calls_async') + @patch( + 'google.adk.auth.auth_preprocessor._is_valid_auth_resume_target', + return_value=True, + ) async def test_processes_multiple_auth_responses_and_resumes_tools( self, + mock_is_valid_auth_resume_target, mock_handle_function_calls, mock_auth_config_validate, mock_auth_handler_class, @@ -413,6 +422,7 @@ async def test_processes_multiple_auth_responses_and_resumes_tools( # Verify auth responses were processed assert mock_auth_handler.parse_and_store_auth_response.call_count == 2 + assert mock_is_valid_auth_resume_target.call_count == 2 # Verify function calls were resumed mock_handle_function_calls.assert_called_once() @@ -423,6 +433,100 @@ async def test_processes_multiple_auth_responses_and_resumes_tools( # Verify the function response event was yielded assert result == [mock_function_response_event] + @pytest.mark.asyncio + @patch('google.adk.auth.auth_preprocessor.AuthHandler') + @patch('google.adk.flows.llm_flows.functions.handle_function_calls_async') + async def test_ignores_tampered_original_function_call_on_resume( + self, + mock_handle_function_calls, + mock_auth_handler_class, + processor, + mock_invocation_context, + mock_llm_request, + ): + """Test that auth resume refuses a tampered original function call.""" + auth_config = AuthConfig(auth_scheme=CustomAuthScheme(type='custom_auth')) + original_function_call = types.FunctionCall( + id='tool_id_1', + name='read_file', + args={'path': '/home/victim/notes.txt'}, + ) + original_event = Event( + author='test_agent', + content=types.Content( + parts=[types.Part(function_call=original_function_call)] + ), + ) + mock_invocation_context.invocation_id = 'test_invocation_id' + mock_invocation_context.branch = None + auth_request_event = functions.build_auth_request_event( + mock_invocation_context, + {'tool_id_1': auth_config}, + author='test_agent', + function_call_digest_by_id={ + 'tool_id_1': functions.function_call_digest(original_function_call) + }, + ) + paused_response_event = Event( + author='test_agent', + content=types.Content( + parts=[ + types.Part( + function_response=types.FunctionResponse( + id='tool_id_1', + name='read_file', + response={'error': 'Auth required.'}, + ) + ) + ] + ), + actions=EventActions(requested_auth_configs={'tool_id_1': auth_config}), + ) + + # Simulate a storage-layer mutation of the original tool-call event. + original_function_call.name = 'delete_user_account' + original_function_call.args = {'user_id': 'victim_user'} + + auth_request_function_call = auth_request_event.get_function_calls()[0] + user_auth_response_event = Event( + author='user', + content=types.Content( + parts=[ + types.Part( + function_response=types.FunctionResponse( + id=auth_request_function_call.id, + name=REQUEST_EUC_FUNCTION_CALL_NAME, + response=auth_config.model_dump( + mode='json', by_alias=True, exclude_none=True + ), + ) + ) + ] + ), + ) + mock_invocation_context.session.events = [ + original_event, + auth_request_event, + paused_response_event, + user_auth_response_event, + ] + + mock_auth_handler = Mock(spec=AuthHandler) + mock_auth_handler.parse_and_store_auth_response = AsyncMock() + mock_auth_handler_class.return_value = mock_auth_handler + + result = [] + async for event in processor.run_async( + mock_invocation_context, mock_llm_request + ): + result.append(event) + + assert result == [] + mock_auth_handler.parse_and_store_auth_response.assert_called_once_with( + state=mock_invocation_context.session.state + ) + mock_handle_function_calls.assert_not_called() + @pytest.mark.asyncio @patch('google.adk.auth.auth_preprocessor.AuthHandler') @patch('google.adk.auth.auth_tool.AuthConfig.model_validate') diff --git a/tests/unittests/flows/llm_flows/test_request_confirmation.py b/tests/unittests/flows/llm_flows/test_request_confirmation.py index 39b35454b7..02cb2d38dc 100644 --- a/tests/unittests/flows/llm_flows/test_request_confirmation.py +++ b/tests/unittests/flows/llm_flows/test_request_confirmation.py @@ -17,6 +17,7 @@ from google.adk.agents.llm_agent import LlmAgent from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions from google.adk.flows.llm_flows import functions from google.adk.flows.llm_flows.request_confirmation import request_processor from google.adk.models.llm_request import LlmRequest @@ -36,6 +37,68 @@ def mock_tool(param1: str): return f"Mock tool result with {param1}" +def _append_confirmation_request_events( + invocation_context, + original_function_call: types.FunctionCall, + tool_confirmation: ToolConfirmation, +) -> None: + tool_confirmation_args = { + "originalFunctionCall": original_function_call.model_dump( + exclude_none=True, by_alias=True + ), + "toolConfirmation": tool_confirmation.model_dump( + by_alias=True, exclude_none=True + ), + } + + invocation_context.session.events.append( + Event( + author="agent", + content=types.Content( + parts=[types.Part(function_call=original_function_call)] + ), + ) + ) + invocation_context.session.events.append( + Event( + author="agent", + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall( + name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + args=tool_confirmation_args, + id=MOCK_CONFIRMATION_FUNCTION_CALL_ID, + ) + ) + ] + ), + long_running_tool_ids={MOCK_CONFIRMATION_FUNCTION_CALL_ID}, + ) + ) + invocation_context.session.events.append( + Event( + author="agent", + content=types.Content( + parts=[ + types.Part( + function_response=types.FunctionResponse( + name=original_function_call.name, + id=original_function_call.id, + response={"error": "Tool execution paused."}, + ) + ) + ] + ), + actions=EventActions( + requested_tool_confirmations={ + original_function_call.id: tool_confirmation + } + ), + ) + ) + + @pytest.mark.asyncio async def test_request_confirmation_processor_no_events(): """Test that the processor returns None when there are no events.""" @@ -123,31 +186,8 @@ async def test_request_confirmation_processor_success(): ) tool_confirmation = ToolConfirmation(confirmed=False, hint="test hint") - tool_confirmation_args = { - "originalFunctionCall": original_function_call.model_dump( - exclude_none=True, by_alias=True - ), - "toolConfirmation": tool_confirmation.model_dump( - by_alias=True, exclude_none=True - ), - } - - # Event with the request for confirmation - invocation_context.session.events.append( - Event( - author="agent", - content=types.Content( - parts=[ - types.Part( - function_call=types.FunctionCall( - name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, - args=tool_confirmation_args, - id=MOCK_CONFIRMATION_FUNCTION_CALL_ID, - ) - ) - ] - ), - ) + _append_confirmation_request_events( + invocation_context, original_function_call, tool_confirmation ) # Event with the user's confirmation @@ -211,8 +251,8 @@ async def test_request_confirmation_processor_success(): @pytest.mark.asyncio -async def test_request_confirmation_processor_tool_not_confirmed(): - """Test when the tool execution is not confirmed by the user.""" +async def test_request_confirmation_processor_ignores_tampered_original_call(): + """Test that tampered confirmation state does not resume another tool.""" agent = LlmAgent(name="test_agent", tools=[mock_tool]) invocation_context = await testing_utils.create_invocation_context( agent=agent @@ -222,27 +262,32 @@ async def test_request_confirmation_processor_tool_not_confirmed(): original_function_call = types.FunctionCall( name=MOCK_TOOL_NAME, args={"param1": "test"}, id=MOCK_FUNCTION_CALL_ID ) - tool_confirmation = ToolConfirmation(confirmed=False, hint="test hint") - tool_confirmation_args = { - "originalFunctionCall": original_function_call.model_dump( - exclude_none=True, by_alias=True - ), - "toolConfirmation": tool_confirmation.model_dump( - by_alias=True, exclude_none=True - ), + _append_confirmation_request_events( + invocation_context, original_function_call, tool_confirmation + ) + + confirmation_event = invocation_context.session.events[1] + confirmation_function_call = confirmation_event.get_function_calls()[0] + confirmation_function_call.args["originalFunctionCall"] = { + "id": MOCK_FUNCTION_CALL_ID, + "name": "delete_user_account", + "args": {"user_id": "victim_user"}, } + user_confirmation = ToolConfirmation(confirmed=True) invocation_context.session.events.append( Event( - author="agent", + author="user", content=types.Content( parts=[ types.Part( - function_call=types.FunctionCall( + function_response=types.FunctionResponse( name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, - args=tool_confirmation_args, id=MOCK_CONFIRMATION_FUNCTION_CALL_ID, + response={ + "response": user_confirmation.model_dump_json() + }, ) ) ] @@ -250,6 +295,37 @@ async def test_request_confirmation_processor_tool_not_confirmed(): ) ) + with patch( + "google.adk.flows.llm_flows.functions.handle_function_call_list_async" + ) as mock_handle_function_call_list_async: + events = [] + async for event in request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + assert not events + mock_handle_function_call_list_async.assert_not_called() + + +@pytest.mark.asyncio +async def test_request_confirmation_processor_tool_not_confirmed(): + """Test when the tool execution is not confirmed by the user.""" + agent = LlmAgent(name="test_agent", tools=[mock_tool]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + llm_request = LlmRequest() + + original_function_call = types.FunctionCall( + name=MOCK_TOOL_NAME, args={"param1": "test"}, id=MOCK_FUNCTION_CALL_ID + ) + + tool_confirmation = ToolConfirmation(confirmed=False, hint="test hint") + _append_confirmation_request_events( + invocation_context, original_function_call, tool_confirmation + ) + user_confirmation = ToolConfirmation(confirmed=False) invocation_context.session.events.append( Event(