Skip to content

Commit ee873ca

Browse files
seanzhougooglecopybara-github
authored andcommitted
feat(auth): Add framework support for toolset authentication before get_tools
This change adds framework-level support for resolving toolset authentication before calling get_tools(). Key changes: - Add _resolve_toolset_auth() method in BaseLlmFlow that iterates through toolsets, checks for auth config, and resolves credentials via CredentialManager before tool listing - Add TOOLSET_AUTH_CREDENTIAL_ID_PREFIX constant for identifying toolset auth requests - Add skip logic in auth_preprocessor to not resume function calls for toolset auth (they do not need it) - Add get_auth_response() method to CallbackContext for retrieving auth credentials from session state - Update CredentialManager to accept CallbackContext instead of requiring ToolContext When a toolset needs authentication but credentials are not available, the flow yields an adk_request_credential event and interrupts the invocation, allowing the user to complete the OAuth flow before retrying. Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 863543036
1 parent fe82f3c commit ee873ca

File tree

6 files changed

+612
-32
lines changed

6 files changed

+612
-32
lines changed

src/google/adk/agents/callback_context.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,25 @@ async def load_credential(
177177
auth_config, self
178178
)
179179

180+
def get_auth_response(
181+
self, auth_config: AuthConfig
182+
) -> Optional[AuthCredential]:
183+
"""Gets the auth response credential from session state.
184+
185+
This method retrieves an authentication credential that was previously
186+
stored in session state after a user completed an OAuth flow or other
187+
authentication process.
188+
189+
Args:
190+
auth_config: The authentication configuration for the credential.
191+
192+
Returns:
193+
The auth credential from the auth response, or None if not found.
194+
"""
195+
from ..auth.auth_handler import AuthHandler
196+
197+
return AuthHandler(auth_config).get_auth_response(self.state)
198+
180199
async def add_session_to_memory(self) -> None:
181200
"""Triggers memory generation for the current session.
182201

src/google/adk/auth/auth_preprocessor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@
3333
if TYPE_CHECKING:
3434
from ..agents.llm_agent import LlmAgent
3535

36+
# Prefix used by toolset auth credential IDs.
37+
# Auth requests with this prefix are for toolset authentication (before tool
38+
# listing) and don't require resuming a function call.
39+
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_'
40+
3641

3742
class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
3843
"""Handles auth information to build the LLM request."""
@@ -96,6 +101,11 @@ async def run_async(
96101
continue
97102
args = AuthToolArguments.model_validate(function_call.args)
98103

104+
# Skip toolset auth - auth response is already stored in session state
105+
# and we don't need to resume a function call for toolsets
106+
if args.function_call_id.startswith(TOOLSET_AUTH_CREDENTIAL_ID_PREFIX):
107+
continue
108+
99109
tools_to_resume.add(args.function_call_id)
100110
if not tools_to_resume:
101111
continue

src/google/adk/auth/credential_manager.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020
from fastapi.openapi.models import OAuth2
2121

22+
from ..agents.callback_context import CallbackContext
2223
from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
23-
from ..tools.tool_context import ToolContext
2424
from ..utils.feature_decorator import experimental
2525
from .auth_credential import AuthCredential
2626
from .auth_credential import AuthCredentialTypes
@@ -124,11 +124,16 @@ def register_credential_exchanger(
124124
"""
125125
self._exchanger_registry.register(credential_type, exchanger_instance)
126126

127-
async def request_credential(self, tool_context: ToolContext) -> None:
128-
tool_context.request_credential(self._auth_config)
127+
async def request_credential(self, context: CallbackContext) -> None:
128+
if not hasattr(context, "request_credential"):
129+
raise TypeError(
130+
"request_credential requires a ToolContext with request_credential"
131+
" method, not a plain CallbackContext"
132+
)
133+
context.request_credential(self._auth_config)
129134

130135
async def get_auth_credential(
131-
self, tool_context: ToolContext
136+
self, context: CallbackContext
132137
) -> Optional[AuthCredential]:
133138
"""Load and prepare authentication credential through a structured workflow."""
134139

@@ -140,14 +145,14 @@ async def get_auth_credential(
140145
return self._auth_config.raw_auth_credential
141146

142147
# Step 3: Try to load existing processed credential
143-
credential = await self._load_existing_credential(tool_context)
148+
credential = await self._load_existing_credential(context)
144149

145150
# Step 4: If no existing credential, load from auth response
146151
# TODO instead of load from auth response, we can store auth response in
147152
# credential service.
148153
was_from_auth_response = False
149154
if not credential:
150-
credential = await self._load_from_auth_response(tool_context)
155+
credential = await self._load_from_auth_response(context)
151156
was_from_auth_response = True
152157

153158
# Step 5: If still no credential available, check if client credentials
@@ -169,38 +174,38 @@ async def get_auth_credential(
169174

170175
# Step 8: Save credential if it was modified
171176
if was_from_auth_response or was_exchanged or was_refreshed:
172-
await self._save_credential(tool_context, credential)
177+
await self._save_credential(context, credential)
173178

174179
return credential
175180

176181
async def _load_existing_credential(
177-
self, tool_context: ToolContext
182+
self, context: CallbackContext
178183
) -> Optional[AuthCredential]:
179184
"""Load existing credential from credential service."""
180185

181186
# Try loading from credential service first
182-
credential = await self._load_from_credential_service(tool_context)
187+
credential = await self._load_from_credential_service(context)
183188
if credential:
184189
return credential
185190

186191
return None
187192

188193
async def _load_from_credential_service(
189-
self, tool_context: ToolContext
194+
self, context: CallbackContext
190195
) -> Optional[AuthCredential]:
191196
"""Load credential from credential service if available."""
192-
credential_service = tool_context._invocation_context.credential_service
197+
credential_service = context._invocation_context.credential_service
193198
if credential_service:
194199
# Note: This should be made async in a future refactor
195200
# For now, assuming synchronous operation
196-
return await tool_context.load_credential(self._auth_config)
201+
return await context.load_credential(self._auth_config)
197202
return None
198203

199204
async def _load_from_auth_response(
200-
self, tool_context: ToolContext
205+
self, context: CallbackContext
201206
) -> Optional[AuthCredential]:
202-
"""Load credential from auth response in tool context."""
203-
return tool_context.get_auth_response(self._auth_config)
207+
"""Load credential from auth response in context."""
208+
return context.get_auth_response(self._auth_config)
204209

205210
async def _exchange_credential(
206211
self, credential: AuthCredential
@@ -290,15 +295,15 @@ async def _validate_credential(self) -> None:
290295
# Additional validation can be added here
291296

292297
async def _save_credential(
293-
self, tool_context: ToolContext, credential: AuthCredential
298+
self, context: CallbackContext, credential: AuthCredential
294299
) -> None:
295300
"""Save credential to credential service if available."""
296301
# Update the exchanged credential in config
297302
self._auth_config.exchanged_auth_credential = credential
298303

299-
credential_service = tool_context._invocation_context.credential_service
304+
credential_service = context._invocation_context.credential_service
300305
if credential_service:
301-
await tool_context.save_credential(self._auth_config)
306+
await context.save_credential(self._auth_config)
302307

303308
async def _populate_auth_scheme(self) -> bool:
304309
"""Auto-discover server metadata and populate missing auth scheme info.

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
from ...agents.readonly_context import ReadonlyContext
3838
from ...agents.run_config import StreamingMode
3939
from ...agents.transcription_entry import TranscriptionEntry
40+
from ...auth.auth_handler import AuthHandler
41+
from ...auth.auth_tool import AuthConfig
42+
from ...auth.credential_manager import CredentialManager
4043
from ...events.event import Event
4144
from ...models.base_llm_connection import BaseLlmConnection
4245
from ...models.llm_request import LlmRequest
@@ -50,6 +53,11 @@
5053
from ...tools.tool_context import ToolContext
5154
from ...utils.context_utils import Aclosing
5255
from .audio_cache_manager import AudioCacheManager
56+
from .functions import build_auth_request_event
57+
from .functions import REQUEST_EUC_FUNCTION_CALL_NAME
58+
59+
# Prefix used by toolset auth credential IDs
60+
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_'
5361

5462
if TYPE_CHECKING:
5563
from ...agents.llm_agent import LlmAgent
@@ -528,6 +536,17 @@ async def _preprocess_async(
528536
async for event in agen:
529537
yield event
530538

539+
# Resolve toolset authentication before tool listing.
540+
# This ensures credentials are ready before get_tools() is called.
541+
async with Aclosing(
542+
self._resolve_toolset_auth(invocation_context, agent)
543+
) as agen:
544+
async for event in agen:
545+
yield event
546+
547+
if invocation_context.end_invocation:
548+
return
549+
531550
# Run processors for tools.
532551

533552
# We may need to wrap some built-in tools if there are other tools
@@ -561,6 +580,81 @@ async def _preprocess_async(
561580
tool_context=tool_context, llm_request=llm_request
562581
)
563582

583+
async def _resolve_toolset_auth(
584+
self,
585+
invocation_context: InvocationContext,
586+
agent: LlmAgent,
587+
) -> AsyncGenerator[Event, None]:
588+
"""Resolves authentication for toolsets before tool listing.
589+
590+
For each toolset with auth configured via get_auth_config():
591+
- If credential is available, populate auth_config.exchanged_auth_credential
592+
- If credential is not available, yield auth request event and interrupt
593+
594+
Args:
595+
invocation_context: The invocation context.
596+
agent: The LLM agent.
597+
598+
Yields:
599+
Auth request events if any toolset needs authentication.
600+
"""
601+
if not agent.tools:
602+
return
603+
604+
pending_auth_requests: dict[str, AuthConfig] = {}
605+
callback_context = CallbackContext(invocation_context)
606+
607+
for tool_union in agent.tools:
608+
if not isinstance(tool_union, BaseToolset):
609+
continue
610+
611+
auth_config = tool_union.get_auth_config()
612+
if not auth_config:
613+
continue
614+
615+
try:
616+
credential = await CredentialManager(auth_config).get_auth_credential(
617+
callback_context
618+
)
619+
except ValueError as e:
620+
# Validation errors from CredentialManager should be logged but not
621+
# block the flow - the toolset may still work without auth
622+
logger.warning(
623+
'Failed to get auth credential for toolset %s: %s',
624+
type(tool_union).__name__,
625+
e,
626+
)
627+
credential = None
628+
629+
if credential:
630+
# Populate in-place for toolset to use in get_tools()
631+
auth_config.exchanged_auth_credential = credential
632+
else:
633+
# Need auth - will interrupt
634+
toolset_id = (
635+
f'{TOOLSET_AUTH_CREDENTIAL_ID_PREFIX}{type(tool_union).__name__}'
636+
)
637+
pending_auth_requests[toolset_id] = auth_config
638+
639+
if not pending_auth_requests:
640+
return
641+
642+
# Build auth requests dict with generated auth requests
643+
auth_requests = {
644+
credential_id: AuthHandler(auth_config).generate_auth_request()
645+
for credential_id, auth_config in pending_auth_requests.items()
646+
}
647+
648+
# Yield event with auth requests using the shared helper
649+
yield build_auth_request_event(
650+
invocation_context,
651+
auth_requests,
652+
author=agent.name,
653+
)
654+
655+
# Interrupt invocation
656+
invocation_context.end_invocation = True
657+
564658
async def _postprocess_async(
565659
self,
566660
invocation_context: InvocationContext,

src/google/adk/flows/llm_flows/functions.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from typing import Any
2727
from typing import AsyncGenerator
2828
from typing import cast
29+
from typing import Dict
2930
from typing import Optional
3031
from typing import TYPE_CHECKING
3132
import uuid
@@ -34,6 +35,7 @@
3435

3536
from ...agents.active_streaming_tool import ActiveStreamingTool
3637
from ...agents.invocation_context import InvocationContext
38+
from ...auth.auth_tool import AuthConfig
3739
from ...auth.auth_tool import AuthToolArguments
3840
from ...events.event import Event
3941
from ...events.event_actions import EventActions
@@ -211,41 +213,77 @@ def get_long_running_function_calls(
211213
return long_running_tool_ids
212214

213215

214-
def generate_auth_event(
216+
def build_auth_request_event(
215217
invocation_context: InvocationContext,
216-
function_response_event: Event,
217-
) -> Optional[Event]:
218-
if not function_response_event.actions.requested_auth_configs:
219-
return None
218+
auth_requests: Dict[str, AuthConfig],
219+
*,
220+
author: Optional[str] = None,
221+
role: Optional[str] = None,
222+
) -> Event:
223+
"""Builds an auth request event with function calls for each auth request.
224+
225+
This is a shared helper used by both tool-level auth (when a tool requests
226+
auth during execution) and toolset-level auth (before tool listing).
227+
228+
Args:
229+
invocation_context: The invocation context.
230+
auth_requests: Dict mapping function_call_id to AuthConfig.
231+
author: The event author. Defaults to agent name.
232+
role: The content role. Defaults to None.
233+
234+
Returns:
235+
Event with auth request function calls.
236+
"""
220237
parts = []
221238
long_running_tool_ids = set()
222-
for (
223-
function_call_id,
224-
auth_config,
225-
) in function_response_event.actions.requested_auth_configs.items():
226239

240+
for function_call_id, auth_config in auth_requests.items():
227241
request_euc_function_call = types.FunctionCall(
228242
name=REQUEST_EUC_FUNCTION_CALL_NAME,
243+
id=generate_client_function_call_id(),
229244
args=AuthToolArguments(
230245
function_call_id=function_call_id,
231246
auth_config=auth_config,
232247
).model_dump(exclude_none=True, by_alias=True),
233248
)
234-
request_euc_function_call.id = generate_client_function_call_id()
235249
long_running_tool_ids.add(request_euc_function_call.id)
236250
parts.append(types.Part(function_call=request_euc_function_call))
237251

238252
return Event(
239253
invocation_id=invocation_context.invocation_id,
240-
author=invocation_context.agent.name,
254+
author=author or invocation_context.agent.name,
241255
branch=invocation_context.branch,
242-
content=types.Content(
243-
parts=parts, role=function_response_event.content.role
244-
),
256+
content=types.Content(parts=parts, role=role),
245257
long_running_tool_ids=long_running_tool_ids,
246258
)
247259

248260

261+
def generate_auth_event(
262+
invocation_context: InvocationContext,
263+
function_response_event: Event,
264+
) -> Optional[Event]:
265+
"""Generates an auth request event from a function response event.
266+
267+
This is used for tool-level auth where a tool requests credentials during
268+
execution.
269+
270+
Args:
271+
invocation_context: The invocation context.
272+
function_response_event: The function response event with auth requests.
273+
274+
Returns:
275+
Event with auth request function calls, or None if no auth requested.
276+
"""
277+
if not function_response_event.actions.requested_auth_configs:
278+
return None
279+
280+
return build_auth_request_event(
281+
invocation_context,
282+
function_response_event.actions.requested_auth_configs,
283+
role=function_response_event.content.role,
284+
)
285+
286+
249287
def generate_request_confirmation_event(
250288
invocation_context: InvocationContext,
251289
function_call_event: Event,

0 commit comments

Comments
 (0)