Skip to content

Commit dc6fbd8

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: Implement the auth provider using Agent Identity Credentials service
PiperOrigin-RevId: 931503059
1 parent 46b9dc3 commit dc6fbd8

4 files changed

Lines changed: 1154 additions & 29 deletions

File tree

src/google/adk/integrations/agent_identity/_agent_identity_credentials_provider.py

Lines changed: 205 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,164 @@
1616

1717
from __future__ import annotations
1818

19+
import asyncio
20+
import logging
21+
import os
22+
import time
23+
1924
from google.adk.agents.callback_context import CallbackContext
2025
from google.adk.auth.auth_credential import AuthCredential
26+
from google.adk.auth.auth_credential import AuthCredentialTypes
27+
from google.adk.auth.auth_credential import HttpAuth
28+
from google.adk.auth.auth_credential import HttpCredentials
29+
from google.adk.auth.auth_credential import OAuth2Auth
30+
from google.adk.flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
31+
from google.api_core.client_options import ClientOptions
32+
33+
try:
34+
from google.cloud.agentidentitycredentials_v1 import AuthProviderCredentialsServiceClient as Client
35+
from google.cloud.agentidentitycredentials_v1 import RetrieveCredentialsRequest
36+
from google.cloud.agentidentitycredentials_v1 import RetrieveCredentialsResponse
37+
except ImportError as e:
38+
raise ImportError(
39+
"Missing required dependencies for Agent Identity Auth Manager. "
40+
'Please install with: pip install "google-adk[agent-identity]"'
41+
) from e
2142

2243
from .gcp_auth_provider_scheme import GcpAuthProviderScheme
2344

45+
# TODO: Catch specific exceptions instead of generic ones.
46+
47+
logger = logging.getLogger("google_adk." + __name__)
48+
49+
NON_INTERACTIVE_TOKEN_POLL_INTERVAL_SEC: float = 1.0
50+
NON_INTERACTIVE_TOKEN_POLL_TIMEOUT_SEC: float = 10.0
51+
52+
53+
def _construct_auth_credential(
54+
response: RetrieveCredentialsResponse,
55+
) -> AuthCredential:
56+
"""Constructs a simplified HTTP auth credential from the header-token tuple
57+
returned by the upstream service.
58+
"""
59+
if not response.success.header or not response.success.token:
60+
raise ValueError(
61+
"Received either empty header or token from Agent Identity"
62+
" Credentials service."
63+
)
64+
65+
header_name, _, header_value = response.success.header.partition(":")
66+
if (
67+
header_name.strip().lower() == "authorization"
68+
and header_value.strip().lower().startswith("bearer")
69+
):
70+
return AuthCredential(
71+
auth_type=AuthCredentialTypes.HTTP,
72+
http=HttpAuth(
73+
scheme="Bearer",
74+
credentials=HttpCredentials(token=response.success.token),
75+
),
76+
)
77+
78+
# Handle custom header.
79+
return AuthCredential(
80+
auth_type=AuthCredentialTypes.HTTP,
81+
http=HttpAuth(
82+
# For custom headers, scheme and credentials fields are not used.
83+
scheme="",
84+
credentials=HttpCredentials(),
85+
additional_headers={
86+
response.success.header: response.success.token,
87+
"X-GOOG-API-KEY": response.success.token,
88+
},
89+
),
90+
)
91+
2492

2593
class _AgentIdentityCredentialsProvider:
2694
"""Auth provider implementation using Agent Identity credentials service."""
2795

96+
_client: Client | None = None
97+
98+
def __init__(self, client: Client | None = None):
99+
self._client = client
100+
101+
def _get_client(self) -> Client:
102+
"""Lazy loads the client to avoid unnecessary setup on startup."""
103+
if self._client is None:
104+
client_options = None
105+
if host := os.environ.get("AGENT_IDENTITY_CREDENTIALS_TARGET_HOST"):
106+
client_options = ClientOptions(api_endpoint=host)
107+
self._client = Client(client_options=client_options, transport="rest")
108+
return self._client
109+
110+
async def _retrieve_credentials(
111+
self,
112+
user_id: str,
113+
auth_scheme: GcpAuthProviderScheme,
114+
) -> RetrieveCredentialsResponse:
115+
request = RetrieveCredentialsRequest(
116+
auth_provider=auth_scheme.name,
117+
user_id=user_id,
118+
scopes=auth_scheme.scopes,
119+
continue_uri=auth_scheme.continue_uri or "",
120+
)
121+
# TODO: Use async client once available. Temporarily using threading to
122+
# prevent blocking the event loop.
123+
return await asyncio.to_thread(
124+
self._get_client().retrieve_credentials, request
125+
)
126+
127+
async def _poll_credentials(
128+
self, user_id: str, auth_scheme: GcpAuthProviderScheme, timeout: float
129+
) -> RetrieveCredentialsResponse:
130+
end_time = time.time() + timeout
131+
while time.time() < end_time:
132+
response = await self._retrieve_credentials(user_id, auth_scheme)
133+
if (
134+
"success" in response
135+
or "uri_consent_required" in response
136+
or "consent_rejected" in response
137+
):
138+
return response
139+
await asyncio.sleep(NON_INTERACTIVE_TOKEN_POLL_INTERVAL_SEC)
140+
raise TimeoutError("Timeout waiting for credentials.")
141+
142+
@staticmethod
143+
def _is_consent_completed(context: CallbackContext) -> bool:
144+
"""Checks if the user consent flow is completed for the current function
145+
146+
call.
147+
"""
148+
if not context.function_call_id:
149+
return False
150+
151+
if not context.session:
152+
return False
153+
154+
events = context.session.events
155+
target_tool_call_id = context.function_call_id
156+
157+
# Find all relevant function calls and responses
158+
euc_calls = {}
159+
euc_responses = {}
160+
161+
for event in events:
162+
for call in event.get_function_calls():
163+
if call.name == REQUEST_EUC_FUNCTION_CALL_NAME:
164+
euc_calls[call.id] = call
165+
for response in event.get_function_responses():
166+
if response.name == REQUEST_EUC_FUNCTION_CALL_NAME:
167+
euc_responses[response.id] = response
168+
169+
# Check for a response that matches a call for the current tool invocation.
170+
for call_id, _ in euc_responses.items():
171+
if call_id in euc_calls:
172+
call = euc_calls[call_id]
173+
if call.args and call.args.get("functionCallId") == target_tool_call_id:
174+
return True
175+
return False
176+
28177
async def get_auth_credential(
29178
self,
30179
auth_scheme: GcpAuthProviderScheme,
@@ -40,10 +189,60 @@ async def get_auth_credential(
40189
An AuthCredential instance.
41190
42191
Raises:
43-
NotImplementedError: Auth provider using Agent Identity Credential service
44-
is not yet supported.
192+
RuntimeError: If credential retrieval or polling fails.
45193
"""
46-
raise NotImplementedError(
47-
"Auth provider using Agent Identity Credential service is not yet"
48-
" supported."
49-
)
194+
195+
if context is None or context.user_id is None:
196+
raise ValueError(
197+
"GcpAuthProvider requires a context with a valid user_id."
198+
)
199+
200+
user_id = context.user_id
201+
202+
try:
203+
response = await self._retrieve_credentials(user_id, auth_scheme)
204+
except Exception as e:
205+
raise RuntimeError(
206+
f"Failed to retrieve credential for user '{user_id}' on"
207+
f" provider '{auth_scheme.name}'."
208+
) from e
209+
210+
if "consent_rejected" in response:
211+
raise RuntimeError("Operation failed: User consent rejected.")
212+
213+
if "success" in response:
214+
logger.debug("Auth credential obtained immediately.")
215+
return _construct_auth_credential(response)
216+
217+
if "pending" in response:
218+
# Get 2-legged OAuth token. Allow enough time for token exchange.
219+
try:
220+
response = await self._poll_credentials(
221+
user_id,
222+
auth_scheme,
223+
timeout=NON_INTERACTIVE_TOKEN_POLL_TIMEOUT_SEC,
224+
)
225+
if "consent_rejected" in response:
226+
raise RuntimeError("Operation failed: User consent rejected.")
227+
if "success" in response:
228+
logger.debug("Auth credential obtained after polling.")
229+
return _construct_auth_credential(response)
230+
except Exception as e:
231+
raise RuntimeError(
232+
f"Failed to retrieve credential for user '{user_id}' on"
233+
f" provider '{auth_scheme.name}'."
234+
) from e
235+
236+
if "uri_consent_required" in response:
237+
if self._is_consent_completed(context):
238+
raise RuntimeError("Failed to retrieve consent based credential.")
239+
240+
# Return AuthCredential with only auth_uri to trigger user consent
241+
# flow.
242+
return AuthCredential(
243+
auth_type=AuthCredentialTypes.OAUTH2,
244+
oauth2=OAuth2Auth(
245+
auth_uri=response.uri_consent_required.authorization_uri,
246+
nonce=response.uri_consent_required.consent_nonce,
247+
),
248+
)

0 commit comments

Comments
 (0)