Skip to content

Commit b270591

Browse files
committed
fix: validate resumed tool calls against prior state
1 parent 029b87d commit b270591

7 files changed

Lines changed: 357 additions & 79 deletions

File tree

src/google/adk/auth/auth_preprocessor.py

Lines changed: 68 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,44 @@
3737
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_'
3838

3939

40+
def _find_function_call(
41+
events: list[Event], function_call_id: str
42+
) -> Any | None:
43+
for event in events:
44+
for function_call in event.get_function_calls():
45+
if function_call.id == function_call_id:
46+
return function_call
47+
return None
48+
49+
50+
def _has_requested_auth_config(
51+
events: list[Event], function_call_id: str
52+
) -> bool:
53+
return any(
54+
function_call_id in event.actions.requested_auth_configs
55+
for event in events
56+
)
57+
58+
59+
def _is_valid_auth_resume_target(
60+
events: list[Event], args: AuthToolArguments
61+
) -> bool:
62+
if args.function_call_id.startswith(TOOLSET_AUTH_CREDENTIAL_ID_PREFIX):
63+
return False
64+
if not _has_requested_auth_config(events, args.function_call_id):
65+
return False
66+
if not args.function_call_digest:
67+
return False
68+
69+
function_call = _find_function_call(events, args.function_call_id)
70+
if not function_call:
71+
return False
72+
73+
return (
74+
functions.function_call_digest(function_call) == args.function_call_digest
75+
)
76+
77+
4078
async def _store_auth_and_collect_resume_targets(
4179
events: list[Event],
4280
auth_fc_ids: set[str],
@@ -64,62 +102,47 @@ async def _store_auth_and_collect_resume_targets(
64102
"""
65103
# Step 1: Scan events for matching adk_request_credential function calls
66104
# to extract AuthToolArguments (contains credential_key).
67-
requested_auth_config_by_id: dict[str, AuthConfig] = {}
105+
requested_auth_args_by_id: dict[str, AuthToolArguments] = {}
68106
for event in events:
69107
event_function_calls = event.get_function_calls()
70108
if not event_function_calls:
71109
continue
72-
try:
73-
for function_call in event_function_calls:
74-
if (
75-
function_call.id in auth_fc_ids
76-
and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
77-
):
78-
args = AuthToolArguments.model_validate(function_call.args)
79-
requested_auth_config_by_id[function_call.id] = args.auth_config
80-
except TypeError:
81-
continue
110+
for function_call in event_function_calls:
111+
if (
112+
function_call.id not in auth_fc_ids
113+
or function_call.name != REQUEST_EUC_FUNCTION_CALL_NAME
114+
):
115+
continue
116+
try:
117+
requested_auth_args_by_id[function_call.id] = (
118+
AuthToolArguments.model_validate(function_call.args)
119+
)
120+
except TypeError:
121+
continue
82122

83123
# Step 2: Store credentials. Merge credential_key from the original
84124
# request into the client's auth response before storing.
85125
for fc_id in auth_fc_ids:
86126
if fc_id not in auth_responses:
87127
continue
88128
auth_config = AuthConfig.model_validate(auth_responses[fc_id])
89-
requested_auth_config = requested_auth_config_by_id.get(fc_id)
90-
if (
91-
requested_auth_config
92-
and requested_auth_config.credential_key is not None
129+
requested_auth_args = requested_auth_args_by_id.get(fc_id)
130+
if requested_auth_args and (
131+
requested_auth_args.auth_config.credential_key is not None
93132
):
94-
auth_config.credential_key = requested_auth_config.credential_key
133+
auth_config.credential_key = (
134+
requested_auth_args.auth_config.credential_key
135+
)
95136
await AuthHandler(auth_config=auth_config).parse_and_store_auth_response(
96137
state=state
97138
)
98139

99140
# Step 3: Collect original function call IDs to resume, skipping
100141
# toolset auth entries which don't map to a resumable function call.
101142
tools_to_resume: set[str] = set()
102-
for fc_id in auth_fc_ids:
103-
requested_auth_config = requested_auth_config_by_id.get(fc_id)
104-
if not requested_auth_config:
105-
continue
106-
# Re-parse to get function_call_id (AuthConfig doesn't carry it;
107-
# AuthToolArguments does).
108-
for event in events:
109-
event_function_calls = event.get_function_calls()
110-
if not event_function_calls:
111-
continue
112-
for function_call in event_function_calls:
113-
if (
114-
function_call.id == fc_id
115-
and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
116-
):
117-
args = AuthToolArguments.model_validate(function_call.args)
118-
if args.function_call_id.startswith(
119-
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX
120-
):
121-
continue
122-
tools_to_resume.add(args.function_call_id)
143+
for args in requested_auth_args_by_id.values():
144+
if _is_valid_auth_resume_target(events, args):
145+
tools_to_resume.add(args.function_call_id)
123146

124147
return tools_to_resume
125148

@@ -141,10 +164,12 @@ async def run_async(
141164
# Find the last user-authored event with function responses to
142165
# identify adk_request_credential responses.
143166
last_event_with_content = None
167+
last_event_with_content_index = -1
144168
for i in range(len(events) - 1, -1, -1):
145169
event = events[i]
146170
if event.content is not None:
147171
last_event_with_content = event
172+
last_event_with_content_index = i
148173
break
149174

150175
if not last_event_with_content or last_event_with_content.author != 'user':
@@ -170,16 +195,20 @@ async def run_async(
170195
return
171196

172197
# Store credentials and collect tools to resume.
198+
prior_events = events[:last_event_with_content_index]
173199
tools_to_resume = await _store_auth_and_collect_resume_targets(
174-
events, auth_fc_ids, auth_responses, invocation_context.session.state
200+
prior_events,
201+
auth_fc_ids,
202+
auth_responses,
203+
invocation_context.session.state,
175204
)
176205

177206
if not tools_to_resume:
178207
return
179208

180209
# Find the original function call event and re-execute the tools
181210
# that needed auth.
182-
for i in range(len(events) - 2, -1, -1):
211+
for i in range(last_event_with_content_index - 1, -1, -1):
183212
event = events[i]
184213
function_calls = event.get_function_calls()
185214
if not function_calls:

src/google/adk/auth/auth_tool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,4 @@ class AuthToolArguments(BaseModelWithConfig):
146146

147147
function_call_id: str
148148
auth_config: AuthConfig
149+
function_call_digest: Optional[str] = None

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1118,7 +1118,7 @@ async def _postprocess_handle_function_calls_async(
11181118
invocation_context, function_call_event, llm_request.tools_dict
11191119
):
11201120
auth_event = functions.generate_auth_event(
1121-
invocation_context, function_response_event
1121+
invocation_context, function_call_event, function_response_event
11221122
)
11231123
if auth_event:
11241124
yield auth_event

src/google/adk/flows/llm_flows/functions.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
from concurrent.futures import ThreadPoolExecutor
2323
import contextvars
2424
import copy
25+
import hashlib
2526
import inspect
27+
import json
2628
import logging
2729
import threading
2830
from typing import Any
@@ -181,6 +183,20 @@ def generate_client_function_call_id() -> str:
181183
return f'{AF_FUNCTION_CALL_ID_PREFIX}{platform_uuid.new_uuid()}'
182184

183185

186+
def function_call_digest(function_call: types.FunctionCall) -> str:
187+
"""Returns a stable digest for a function call."""
188+
dumped = function_call.model_dump(
189+
exclude_none=True, by_alias=True, mode='json'
190+
)
191+
canonical_json = json.dumps(
192+
dumped,
193+
sort_keys=True,
194+
ensure_ascii=False,
195+
separators=(',', ':'),
196+
)
197+
return hashlib.sha256(canonical_json.encode('utf-8')).hexdigest()
198+
199+
184200
def populate_client_function_call_id(model_response_event: Event) -> None:
185201
if not model_response_event.get_function_calls():
186202
return
@@ -235,6 +251,7 @@ def build_auth_request_event(
235251
*,
236252
author: Optional[str] = None,
237253
role: Optional[str] = None,
254+
function_call_digest_by_id: Optional[dict[str, str]] = None,
238255
) -> Event:
239256
"""Builds an auth request event with function calls for each auth request.
240257
@@ -246,6 +263,8 @@ def build_auth_request_event(
246263
auth_requests: Dict mapping function_call_id to AuthConfig.
247264
author: The event author. Defaults to agent name.
248265
role: The content role. Defaults to None.
266+
function_call_digest_by_id: Optional mapping of function call IDs to stable
267+
digests for tool-level auth requests.
249268
250269
Returns:
251270
Event with auth request function calls.
@@ -260,6 +279,9 @@ def build_auth_request_event(
260279
args=AuthToolArguments(
261280
function_call_id=function_call_id,
262281
auth_config=auth_config,
282+
function_call_digest=(function_call_digest_by_id or {}).get(
283+
function_call_id
284+
),
263285
).model_dump(exclude_none=True, by_alias=True),
264286
)
265287
long_running_tool_ids.add(request_euc_function_call.id)
@@ -276,6 +298,7 @@ def build_auth_request_event(
276298

277299
def generate_auth_event(
278300
invocation_context: InvocationContext,
301+
function_call_event: Event,
279302
function_response_event: Event,
280303
) -> Optional[Event]:
281304
"""Generates an auth request event from a function response event.
@@ -285,6 +308,7 @@ def generate_auth_event(
285308
286309
Args:
287310
invocation_context: The invocation context.
311+
function_call_event: The function call event that produced the response.
288312
function_response_event: The function response event with auth requests.
289313
290314
Returns:
@@ -293,10 +317,18 @@ def generate_auth_event(
293317
if not function_response_event.actions.requested_auth_configs:
294318
return None
295319

320+
function_call_digest_by_id = {
321+
function_call.id: function_call_digest(function_call)
322+
for function_call in function_call_event.get_function_calls()
323+
if function_call.id
324+
in function_response_event.actions.requested_auth_configs
325+
}
326+
296327
return build_auth_request_event(
297328
invocation_context,
298329
function_response_event.actions.requested_auth_configs,
299330
role=function_response_event.content.role,
331+
function_call_digest_by_id=function_call_digest_by_id,
300332
)
301333

302334

src/google/adk/flows/llm_flows/request_confirmation.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,26 @@ def _parse_tool_confirmation(response: dict[str, Any]) -> ToolConfirmation:
5050
return ToolConfirmation.model_validate(response)
5151

5252

53+
def _has_matching_original_function_call(
54+
events: list[Event], original_function_call: types.FunctionCall
55+
) -> bool:
56+
expected_digest = functions.function_call_digest(original_function_call)
57+
for event in events:
58+
for function_call in event.get_function_calls():
59+
if functions.function_call_digest(function_call) == expected_digest:
60+
return True
61+
return False
62+
63+
64+
def _has_requested_tool_confirmation(
65+
events: list[Event], function_call_id: str
66+
) -> bool:
67+
return any(
68+
function_call_id in event.actions.requested_tool_confirmations
69+
for event in events
70+
)
71+
72+
5373
def _resolve_confirmation_targets(
5474
events: list[Event],
5575
confirmation_fc_ids: set[str],
@@ -82,13 +102,28 @@ def _resolve_confirmation_targets(
82102
for function_call in event_function_calls:
83103
if function_call.id not in confirmation_fc_ids:
84104
continue
105+
if function_call.name != REQUEST_CONFIRMATION_FUNCTION_CALL_NAME:
106+
continue
107+
if (
108+
not event.long_running_tool_ids
109+
or function_call.id not in event.long_running_tool_ids
110+
):
111+
continue
85112

86113
args = function_call.args
87114
if 'originalFunctionCall' not in args:
88115
continue
89116
original_function_call = types.FunctionCall(
90117
**args['originalFunctionCall']
91118
)
119+
if not _has_matching_original_function_call(
120+
events, original_function_call
121+
):
122+
continue
123+
if not _has_requested_tool_confirmation(
124+
events, original_function_call.id
125+
):
126+
continue
92127
tool_confirmation_dict[original_function_call.id] = (
93128
confirmations_by_fc_id[function_call.id]
94129
)
@@ -139,9 +174,10 @@ async def run_async(
139174

140175
# Step 2: Resolve confirmation targets using extracted helper.
141176
confirmation_fc_ids = set(confirmations_by_fc_id.keys())
177+
prior_events = events[:confirmation_event_index]
142178
tools_to_resume_with_confirmation, tools_to_resume_with_args = (
143179
_resolve_confirmation_targets(
144-
events, confirmation_fc_ids, confirmations_by_fc_id
180+
prior_events, confirmation_fc_ids, confirmations_by_fc_id
145181
)
146182
)
147183

0 commit comments

Comments
 (0)