Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 68 additions & 39 deletions src/google/adk/auth/auth_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,44 @@
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_'


def _find_function_call(
events: list[Event], function_call_id: str
) -> Any | None:
for event in events:
for function_call in event.get_function_calls():
if function_call.id == function_call_id:
return function_call
return None


def _has_requested_auth_config(
events: list[Event], function_call_id: str
) -> bool:
return any(
function_call_id in event.actions.requested_auth_configs
for event in events
)


def _is_valid_auth_resume_target(
events: list[Event], args: AuthToolArguments
) -> bool:
if args.function_call_id.startswith(TOOLSET_AUTH_CREDENTIAL_ID_PREFIX):
return False
if not _has_requested_auth_config(events, args.function_call_id):
return False
if not args.function_call_digest:
return False

function_call = _find_function_call(events, args.function_call_id)
if not function_call:
return False

return (
functions.function_call_digest(function_call) == args.function_call_digest
)


async def _store_auth_and_collect_resume_targets(
events: list[Event],
auth_fc_ids: set[str],
Expand Down Expand Up @@ -64,62 +102,47 @@ async def _store_auth_and_collect_resume_targets(
"""
# Step 1: Scan events for matching adk_request_credential function calls
# to extract AuthToolArguments (contains credential_key).
requested_auth_config_by_id: dict[str, AuthConfig] = {}
requested_auth_args_by_id: dict[str, AuthToolArguments] = {}
for event in events:
event_function_calls = event.get_function_calls()
if not event_function_calls:
continue
try:
for function_call in event_function_calls:
if (
function_call.id in auth_fc_ids
and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
):
args = AuthToolArguments.model_validate(function_call.args)
requested_auth_config_by_id[function_call.id] = args.auth_config
except TypeError:
continue
for function_call in event_function_calls:
if (
function_call.id not in auth_fc_ids
or function_call.name != REQUEST_EUC_FUNCTION_CALL_NAME
):
continue
try:
requested_auth_args_by_id[function_call.id] = (
AuthToolArguments.model_validate(function_call.args)
)
except TypeError:
continue

# Step 2: Store credentials. Merge credential_key from the original
# request into the client's auth response before storing.
for fc_id in auth_fc_ids:
if fc_id not in auth_responses:
continue
auth_config = AuthConfig.model_validate(auth_responses[fc_id])
requested_auth_config = requested_auth_config_by_id.get(fc_id)
if (
requested_auth_config
and requested_auth_config.credential_key is not None
requested_auth_args = requested_auth_args_by_id.get(fc_id)
if requested_auth_args and (
requested_auth_args.auth_config.credential_key is not None
):
auth_config.credential_key = requested_auth_config.credential_key
auth_config.credential_key = (
requested_auth_args.auth_config.credential_key
)
await AuthHandler(auth_config=auth_config).parse_and_store_auth_response(
state=state
)

# Step 3: Collect original function call IDs to resume, skipping
# toolset auth entries which don't map to a resumable function call.
tools_to_resume: set[str] = set()
for fc_id in auth_fc_ids:
requested_auth_config = requested_auth_config_by_id.get(fc_id)
if not requested_auth_config:
continue
# Re-parse to get function_call_id (AuthConfig doesn't carry it;
# AuthToolArguments does).
for event in events:
event_function_calls = event.get_function_calls()
if not event_function_calls:
continue
for function_call in event_function_calls:
if (
function_call.id == fc_id
and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
):
args = AuthToolArguments.model_validate(function_call.args)
if args.function_call_id.startswith(
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX
):
continue
tools_to_resume.add(args.function_call_id)
for args in requested_auth_args_by_id.values():
if _is_valid_auth_resume_target(events, args):
tools_to_resume.add(args.function_call_id)

return tools_to_resume

Expand All @@ -141,10 +164,12 @@ async def run_async(
# Find the last user-authored event with function responses to
# identify adk_request_credential responses.
last_event_with_content = None
last_event_with_content_index = -1
for i in range(len(events) - 1, -1, -1):
event = events[i]
if event.content is not None:
last_event_with_content = event
last_event_with_content_index = i
break

if not last_event_with_content or last_event_with_content.author != 'user':
Expand All @@ -170,16 +195,20 @@ async def run_async(
return

# Store credentials and collect tools to resume.
prior_events = events[:last_event_with_content_index]
tools_to_resume = await _store_auth_and_collect_resume_targets(
events, auth_fc_ids, auth_responses, invocation_context.session.state
prior_events,
auth_fc_ids,
auth_responses,
invocation_context.session.state,
)

if not tools_to_resume:
return

# Find the original function call event and re-execute the tools
# that needed auth.
for i in range(len(events) - 2, -1, -1):
for i in range(last_event_with_content_index - 1, -1, -1):
event = events[i]
function_calls = event.get_function_calls()
if not function_calls:
Expand Down
1 change: 1 addition & 0 deletions src/google/adk/auth/auth_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,4 @@ class AuthToolArguments(BaseModelWithConfig):

function_call_id: str
auth_config: AuthConfig
function_call_digest: Optional[str] = None
2 changes: 1 addition & 1 deletion src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,7 @@ async def _postprocess_handle_function_calls_async(
invocation_context, function_call_event, llm_request.tools_dict
):
auth_event = functions.generate_auth_event(
invocation_context, function_response_event
invocation_context, function_call_event, function_response_event
)
if auth_event:
yield auth_event
Expand Down
32 changes: 32 additions & 0 deletions src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from concurrent.futures import ThreadPoolExecutor
import contextvars
import copy
import hashlib
import inspect
import json
import logging
import threading
from typing import Any
Expand Down Expand Up @@ -181,6 +183,20 @@ def generate_client_function_call_id() -> str:
return f'{AF_FUNCTION_CALL_ID_PREFIX}{platform_uuid.new_uuid()}'


def function_call_digest(function_call: types.FunctionCall) -> str:
"""Returns a stable digest for a function call."""
dumped = function_call.model_dump(
exclude_none=True, by_alias=True, mode='json'
)
canonical_json = json.dumps(
dumped,
sort_keys=True,
ensure_ascii=False,
separators=(',', ':'),
)
return hashlib.sha256(canonical_json.encode('utf-8')).hexdigest()


def populate_client_function_call_id(model_response_event: Event) -> None:
if not model_response_event.get_function_calls():
return
Expand Down Expand Up @@ -235,6 +251,7 @@ def build_auth_request_event(
*,
author: Optional[str] = None,
role: Optional[str] = None,
function_call_digest_by_id: Optional[dict[str, str]] = None,
) -> Event:
"""Builds an auth request event with function calls for each auth request.

Expand All @@ -246,6 +263,8 @@ def build_auth_request_event(
auth_requests: Dict mapping function_call_id to AuthConfig.
author: The event author. Defaults to agent name.
role: The content role. Defaults to None.
function_call_digest_by_id: Optional mapping of function call IDs to stable
digests for tool-level auth requests.

Returns:
Event with auth request function calls.
Expand All @@ -260,6 +279,9 @@ def build_auth_request_event(
args=AuthToolArguments(
function_call_id=function_call_id,
auth_config=auth_config,
function_call_digest=(function_call_digest_by_id or {}).get(
function_call_id
),
).model_dump(exclude_none=True, by_alias=True),
)
long_running_tool_ids.add(request_euc_function_call.id)
Expand All @@ -276,6 +298,7 @@ def build_auth_request_event(

def generate_auth_event(
invocation_context: InvocationContext,
function_call_event: Event,
function_response_event: Event,
) -> Optional[Event]:
"""Generates an auth request event from a function response event.
Expand All @@ -285,6 +308,7 @@ def generate_auth_event(

Args:
invocation_context: The invocation context.
function_call_event: The function call event that produced the response.
function_response_event: The function response event with auth requests.

Returns:
Expand All @@ -293,10 +317,18 @@ def generate_auth_event(
if not function_response_event.actions.requested_auth_configs:
return None

function_call_digest_by_id = {
function_call.id: function_call_digest(function_call)
for function_call in function_call_event.get_function_calls()
if function_call.id
in function_response_event.actions.requested_auth_configs
}

return build_auth_request_event(
invocation_context,
function_response_event.actions.requested_auth_configs,
role=function_response_event.content.role,
function_call_digest_by_id=function_call_digest_by_id,
)


Expand Down
38 changes: 37 additions & 1 deletion src/google/adk/flows/llm_flows/request_confirmation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,26 @@ def _parse_tool_confirmation(response: dict[str, Any]) -> ToolConfirmation:
return ToolConfirmation.model_validate(response)


def _has_matching_original_function_call(
events: list[Event], original_function_call: types.FunctionCall
) -> bool:
expected_digest = functions.function_call_digest(original_function_call)
for event in events:
for function_call in event.get_function_calls():
if functions.function_call_digest(function_call) == expected_digest:
return True
return False


def _has_requested_tool_confirmation(
events: list[Event], function_call_id: str
) -> bool:
return any(
function_call_id in event.actions.requested_tool_confirmations
for event in events
)


def _resolve_confirmation_targets(
events: list[Event],
confirmation_fc_ids: set[str],
Expand Down Expand Up @@ -82,13 +102,28 @@ def _resolve_confirmation_targets(
for function_call in event_function_calls:
if function_call.id not in confirmation_fc_ids:
continue
if function_call.name != REQUEST_CONFIRMATION_FUNCTION_CALL_NAME:
continue
if (
not event.long_running_tool_ids
or function_call.id not in event.long_running_tool_ids
):
continue

args = function_call.args
if 'originalFunctionCall' not in args:
continue
original_function_call = types.FunctionCall(
**args['originalFunctionCall']
)
if not _has_matching_original_function_call(
events, original_function_call
):
continue
if not _has_requested_tool_confirmation(
events, original_function_call.id
):
continue
tool_confirmation_dict[original_function_call.id] = (
confirmations_by_fc_id[function_call.id]
)
Expand Down Expand Up @@ -139,9 +174,10 @@ async def run_async(

# Step 2: Resolve confirmation targets using extracted helper.
confirmation_fc_ids = set(confirmations_by_fc_id.keys())
prior_events = events[:confirmation_event_index]
tools_to_resume_with_confirmation, tools_to_resume_with_args = (
_resolve_confirmation_targets(
events, confirmation_fc_ids, confirmations_by_fc_id
prior_events, confirmation_fc_ids, confirmations_by_fc_id
)
)

Expand Down
Loading