1616
1717from __future__ import annotations
1818
19+ import asyncio
20+ import logging
21+ import os
22+ import time
23+
1924from google .adk .agents .callback_context import CallbackContext
2025from google .adk .auth .auth_credential import AuthCredential
26+ from google .adk .auth .auth_credential import AuthCredentialTypes
27+ from google .adk .auth .auth_credential import HttpAuth
28+ from google .adk .auth .auth_credential import HttpCredentials
29+ from google .adk .auth .auth_credential import OAuth2Auth
30+ from google .adk .flows .llm_flows .functions import REQUEST_EUC_FUNCTION_CALL_NAME
31+ from google .api_core .client_options import ClientOptions
32+
33+ try :
34+ from google .cloud .agentidentitycredentials_v1 import AuthProviderCredentialsServiceClient as Client
35+ from google .cloud .agentidentitycredentials_v1 import RetrieveCredentialsRequest
36+ from google .cloud .agentidentitycredentials_v1 import RetrieveCredentialsResponse
37+ except ImportError as e :
38+ raise ImportError (
39+ "Missing required dependencies for Agent Identity Auth Manager. "
40+ 'Please install with: pip install "google-adk[agent-identity]"'
41+ ) from e
2142
2243from .gcp_auth_provider_scheme import GcpAuthProviderScheme
2344
45+ # TODO: Catch specific exceptions instead of generic ones.
46+
47+ logger = logging .getLogger ("google_adk." + __name__ )
48+
49+ NON_INTERACTIVE_TOKEN_POLL_INTERVAL_SEC : float = 1.0
50+ NON_INTERACTIVE_TOKEN_POLL_TIMEOUT_SEC : float = 10.0
51+
52+
53+ def _construct_auth_credential (
54+ response : RetrieveCredentialsResponse ,
55+ ) -> AuthCredential :
56+ """Constructs a simplified HTTP auth credential from the header-token tuple
57+ returned by the upstream service.
58+ """
59+ if not response .success .header or not response .success .token :
60+ raise ValueError (
61+ "Received either empty header or token from Agent Identity"
62+ " Credentials service."
63+ )
64+
65+ header_name , _ , header_value = response .success .header .partition (":" )
66+ if (
67+ header_name .strip ().lower () == "authorization"
68+ and header_value .strip ().lower ().startswith ("bearer" )
69+ ):
70+ return AuthCredential (
71+ auth_type = AuthCredentialTypes .HTTP ,
72+ http = HttpAuth (
73+ scheme = "Bearer" ,
74+ credentials = HttpCredentials (token = response .success .token ),
75+ ),
76+ )
77+
78+ # Handle custom header.
79+ return AuthCredential (
80+ auth_type = AuthCredentialTypes .HTTP ,
81+ http = HttpAuth (
82+ # For custom headers, scheme and credentials fields are not used.
83+ scheme = "" ,
84+ credentials = HttpCredentials (),
85+ additional_headers = {
86+ response .success .header : response .success .token ,
87+ "X-GOOG-API-KEY" : response .success .token ,
88+ },
89+ ),
90+ )
91+
2492
2593class _AgentIdentityCredentialsProvider :
2694 """Auth provider implementation using Agent Identity credentials service."""
2795
96+ _client : Client | None = None
97+
98+ def __init__ (self , client : Client | None = None ):
99+ self ._client = client
100+
101+ def _get_client (self ) -> Client :
102+ """Lazy loads the client to avoid unnecessary setup on startup."""
103+ if self ._client is None :
104+ client_options = None
105+ if host := os .environ .get ("AGENT_IDENTITY_CREDENTIALS_TARGET_HOST" ):
106+ client_options = ClientOptions (api_endpoint = host )
107+ self ._client = Client (client_options = client_options , transport = "rest" )
108+ return self ._client
109+
110+ async def _retrieve_credentials (
111+ self ,
112+ user_id : str ,
113+ auth_scheme : GcpAuthProviderScheme ,
114+ ) -> RetrieveCredentialsResponse :
115+ request = RetrieveCredentialsRequest (
116+ auth_provider = auth_scheme .name ,
117+ user_id = user_id ,
118+ scopes = auth_scheme .scopes ,
119+ continue_uri = auth_scheme .continue_uri or "" ,
120+ )
121+ # TODO: Use async client once available. Temporarily using threading to
122+ # prevent blocking the event loop.
123+ return await asyncio .to_thread (
124+ self ._get_client ().retrieve_credentials , request
125+ )
126+
127+ async def _poll_credentials (
128+ self , user_id : str , auth_scheme : GcpAuthProviderScheme , timeout : float
129+ ) -> RetrieveCredentialsResponse :
130+ end_time = time .time () + timeout
131+ while time .time () < end_time :
132+ response = await self ._retrieve_credentials (user_id , auth_scheme )
133+ if (
134+ "success" in response
135+ or "uri_consent_required" in response
136+ or "consent_rejected" in response
137+ ):
138+ return response
139+ await asyncio .sleep (NON_INTERACTIVE_TOKEN_POLL_INTERVAL_SEC )
140+ raise TimeoutError ("Timeout waiting for credentials." )
141+
142+ @staticmethod
143+ def _is_consent_completed (context : CallbackContext ) -> bool :
144+ """Checks if the user consent flow is completed for the current function
145+
146+ call.
147+ """
148+ if not context .function_call_id :
149+ return False
150+
151+ if not context .session :
152+ return False
153+
154+ events = context .session .events
155+ target_tool_call_id = context .function_call_id
156+
157+ # Find all relevant function calls and responses
158+ euc_calls = {}
159+ euc_responses = {}
160+
161+ for event in events :
162+ for call in event .get_function_calls ():
163+ if call .name == REQUEST_EUC_FUNCTION_CALL_NAME :
164+ euc_calls [call .id ] = call
165+ for response in event .get_function_responses ():
166+ if response .name == REQUEST_EUC_FUNCTION_CALL_NAME :
167+ euc_responses [response .id ] = response
168+
169+ # Check for a response that matches a call for the current tool invocation.
170+ for call_id , _ in euc_responses .items ():
171+ if call_id in euc_calls :
172+ call = euc_calls [call_id ]
173+ if call .args and call .args .get ("functionCallId" ) == target_tool_call_id :
174+ return True
175+ return False
176+
28177 async def get_auth_credential (
29178 self ,
30179 auth_scheme : GcpAuthProviderScheme ,
@@ -40,10 +189,60 @@ async def get_auth_credential(
40189 An AuthCredential instance.
41190
42191 Raises:
43- NotImplementedError: Auth provider using Agent Identity Credential service
44- is not yet supported.
192+ RuntimeError: If credential retrieval or polling fails.
45193 """
46- raise NotImplementedError (
47- "Auth provider using Agent Identity Credential service is not yet"
48- " supported."
49- )
194+
195+ if context is None or context .user_id is None :
196+ raise ValueError (
197+ "GcpAuthProvider requires a context with a valid user_id."
198+ )
199+
200+ user_id = context .user_id
201+
202+ try :
203+ response = await self ._retrieve_credentials (user_id , auth_scheme )
204+ except Exception as e :
205+ raise RuntimeError (
206+ f"Failed to retrieve credential for user '{ user_id } ' on"
207+ f" provider '{ auth_scheme .name } '."
208+ ) from e
209+
210+ if "consent_rejected" in response :
211+ raise RuntimeError ("Operation failed: User consent rejected." )
212+
213+ if "success" in response :
214+ logger .debug ("Auth credential obtained immediately." )
215+ return _construct_auth_credential (response )
216+
217+ if "pending" in response :
218+ # Get 2-legged OAuth token. Allow enough time for token exchange.
219+ try :
220+ response = await self ._poll_credentials (
221+ user_id ,
222+ auth_scheme ,
223+ timeout = NON_INTERACTIVE_TOKEN_POLL_TIMEOUT_SEC ,
224+ )
225+ if "consent_rejected" in response :
226+ raise RuntimeError ("Operation failed: User consent rejected." )
227+ if "success" in response :
228+ logger .debug ("Auth credential obtained after polling." )
229+ return _construct_auth_credential (response )
230+ except Exception as e :
231+ raise RuntimeError (
232+ f"Failed to retrieve credential for user '{ user_id } ' on"
233+ f" provider '{ auth_scheme .name } '."
234+ ) from e
235+
236+ if "uri_consent_required" in response :
237+ if self ._is_consent_completed (context ):
238+ raise RuntimeError ("Failed to retrieve consent based credential." )
239+
240+ # Return AuthCredential with only auth_uri to trigger user consent
241+ # flow.
242+ return AuthCredential (
243+ auth_type = AuthCredentialTypes .OAUTH2 ,
244+ oauth2 = OAuth2Auth (
245+ auth_uri = response .uri_consent_required .authorization_uri ,
246+ nonce = response .uri_consent_required .consent_nonce ,
247+ ),
248+ )
0 commit comments