Skip to content

Commit c423fcd

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: Move the IamConnectorCredential service depedency to a seperate file
PiperOrigin-RevId: 931088283
1 parent 57bdecf commit c423fcd

6 files changed

Lines changed: 774 additions & 657 deletions

File tree

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import asyncio
18+
import logging
19+
import os
20+
import time
21+
22+
from google.adk.agents.callback_context import CallbackContext
23+
from google.adk.auth.auth_credential import AuthCredential
24+
from google.adk.auth.auth_credential import AuthCredentialTypes
25+
from google.adk.auth.auth_credential import HttpAuth
26+
from google.adk.auth.auth_credential import HttpCredentials
27+
from google.adk.auth.auth_credential import OAuth2Auth
28+
from google.adk.flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
29+
from google.api_core.client_options import ClientOptions
30+
31+
try:
32+
from google.cloud.iamconnectorcredentials_v1alpha import IAMConnectorCredentialsServiceClient as Client
33+
from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsMetadata
34+
from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsRequest
35+
from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsResponse
36+
except ImportError as e:
37+
raise ImportError(
38+
"Missing required dependencies for Agent Identity Auth Manager. "
39+
'Please install with: pip install "google-adk[agent-identity]"'
40+
) from e
41+
from google.longrunning.operations_pb2 import Operation
42+
43+
from .gcp_auth_provider_scheme import GcpAuthProviderScheme
44+
45+
# Notes on the current IAM Connector Credentials service implementation:
46+
# 1. The service does not yet support LROs, so even though the
47+
# retrieve_credentials method returns an Operation object, the methods like
48+
# operation.done() and operation.result() will not work yet.
49+
# 2. For API key flows, the returned Operation contains the credentials.
50+
# 3. For 2-legged OAuth flows, the returned Operation contains pending status,
51+
# client needs to retry the request until response with credentials is
52+
# returned or timeout occurs.
53+
# 4. For 3-legged OAuth flows, the returned Operation contains consent pending
54+
# status along with the authorization URI.
55+
56+
# TODO: Catch specific exceptions instead of generic ones.
57+
58+
logger = logging.getLogger("google_adk." + __name__)
59+
60+
NON_INTERACTIVE_TOKEN_POLL_INTERVAL_SEC: float = 1.0
61+
NON_INTERACTIVE_TOKEN_POLL_TIMEOUT_SEC: float = 10.0
62+
63+
64+
def _construct_auth_credential(
65+
response: RetrieveCredentialsResponse,
66+
) -> AuthCredential:
67+
"""Constructs a simplified HTTP auth credential from the header-token tuple returned by the upstream service."""
68+
if not response.header or not response.token:
69+
raise ValueError(
70+
"Received either empty header or token from IAM Connector Credentials"
71+
" service."
72+
)
73+
74+
header_name, _, header_value = response.header.partition(":")
75+
if (
76+
header_name.strip().lower() == "authorization"
77+
and header_value.strip().lower().startswith("bearer")
78+
):
79+
return AuthCredential(
80+
auth_type=AuthCredentialTypes.HTTP,
81+
http=HttpAuth(
82+
scheme="Bearer",
83+
credentials=HttpCredentials(token=response.token),
84+
),
85+
)
86+
87+
# Handle custom header.
88+
return AuthCredential(
89+
auth_type=AuthCredentialTypes.HTTP,
90+
http=HttpAuth(
91+
# For custom headers, scheme and credentials fields are not used.
92+
scheme="",
93+
credentials=HttpCredentials(),
94+
additional_headers={
95+
response.header: response.token,
96+
"X-GOOG-API-KEY": response.token,
97+
},
98+
),
99+
)
100+
101+
102+
class _IamConnectorCredentialsProvider:
103+
"""Implementation for auth provider using IAM Connector credentials service."""
104+
105+
_client: Client | None = None
106+
107+
def __init__(self, client: Client | None = None):
108+
self._client = client
109+
110+
def _get_client(self) -> Client:
111+
"""Lazy loads the client to avoid unnecessary setup on startup."""
112+
if self._client is None:
113+
client_options = None
114+
if host := os.environ.get("IAM_CONNECTOR_CREDENTIALS_TARGET_HOST"):
115+
client_options = ClientOptions(api_endpoint=host)
116+
self._client = Client(client_options=client_options, transport="rest")
117+
return self._client
118+
119+
async def _retrieve_credentials(
120+
self,
121+
user_id: str,
122+
auth_scheme: GcpAuthProviderScheme,
123+
) -> Operation:
124+
request = RetrieveCredentialsRequest(
125+
connector=auth_scheme.name,
126+
user_id=user_id,
127+
scopes=auth_scheme.scopes,
128+
continue_uri=auth_scheme.continue_uri or "",
129+
force_refresh=False,
130+
)
131+
# TODO: Use async client once available. Temporarily using threading to
132+
# prevent blocking the event loop.
133+
operation = await asyncio.to_thread(
134+
self._get_client().retrieve_credentials, request
135+
)
136+
return operation.operation
137+
138+
def _unpack_operation(
139+
self, operation: Operation
140+
) -> tuple[
141+
RetrieveCredentialsResponse | None, RetrieveCredentialsMetadata | None
142+
]:
143+
"""Deserializes the response and metadata from the operation."""
144+
response = None
145+
metadata = None
146+
if operation.response:
147+
response = RetrieveCredentialsResponse.deserialize(
148+
operation.response.value
149+
)
150+
if operation.metadata:
151+
metadata = RetrieveCredentialsMetadata.deserialize(
152+
operation.metadata.value
153+
)
154+
return response, metadata
155+
156+
async def _poll_credentials(
157+
self, user_id: str, auth_scheme: GcpAuthProviderScheme, timeout: float
158+
) -> Operation:
159+
end_time = time.time() + timeout
160+
while time.time() < end_time:
161+
operation = await self._retrieve_credentials(user_id, auth_scheme)
162+
if operation.done:
163+
return operation
164+
await asyncio.sleep(NON_INTERACTIVE_TOKEN_POLL_INTERVAL_SEC)
165+
raise TimeoutError("Timeout waiting for credentials.")
166+
167+
@staticmethod
168+
def _is_consent_completed(context: CallbackContext) -> bool:
169+
"""Checks if the user consent flow is completed for the current function call."""
170+
if not context.function_call_id:
171+
return False
172+
173+
if not context.session:
174+
return False
175+
176+
events = context.session.events
177+
target_tool_call_id = context.function_call_id
178+
179+
# Find all relevant function calls and responses
180+
euc_calls = {}
181+
euc_responses = {}
182+
183+
for event in events:
184+
for call in event.get_function_calls():
185+
if call.name == REQUEST_EUC_FUNCTION_CALL_NAME:
186+
euc_calls[call.id] = call
187+
for response in event.get_function_responses():
188+
if response.name == REQUEST_EUC_FUNCTION_CALL_NAME:
189+
euc_responses[response.id] = response
190+
191+
# Check for a response that matches a call for the current tool invocation
192+
for call_id, _ in euc_responses.items():
193+
if call_id in euc_calls:
194+
call = euc_calls[call_id]
195+
if call.args and call.args.get("functionCallId") == target_tool_call_id:
196+
return True
197+
return False
198+
199+
async def get_auth_credential(
200+
self,
201+
auth_scheme: GcpAuthProviderScheme,
202+
context: CallbackContext | None = None,
203+
) -> AuthCredential:
204+
"""Retrieves credentials using the IAM Connector Credentials service.
205+
206+
Args:
207+
auth_scheme: The GcpAuthProviderScheme.
208+
context: Optional context for the callback.
209+
210+
Returns:
211+
An AuthCredential instance.
212+
213+
Raises:
214+
RuntimeError: If credential retrieval or polling fails.
215+
"""
216+
217+
if context is None or context.user_id is None:
218+
raise ValueError(
219+
"GcpAuthProvider requires a context with a valid user_id."
220+
)
221+
222+
user_id = context.user_id
223+
224+
try:
225+
operation = await self._retrieve_credentials(user_id, auth_scheme)
226+
except Exception as e:
227+
raise RuntimeError(
228+
f"Failed to retrieve credential for user '{user_id}' on connector"
229+
f" '{auth_scheme.name}'."
230+
) from e
231+
232+
response, metadata = self._unpack_operation(operation)
233+
234+
if operation.HasField("error"):
235+
raise RuntimeError(f"Operation failed: {operation.error.message}")
236+
237+
if operation.done:
238+
logger.debug("Auth credential obtained immediately.")
239+
return _construct_auth_credential(response)
240+
241+
if metadata and metadata.consent_pending:
242+
# Get 2-legged OAuth token. Allow enough time for token exchange.
243+
try:
244+
operation = await self._poll_credentials(
245+
user_id,
246+
auth_scheme,
247+
timeout=NON_INTERACTIVE_TOKEN_POLL_TIMEOUT_SEC,
248+
)
249+
if operation.HasField("error"):
250+
raise RuntimeError(f"Operation failed: {operation.error.message}")
251+
if operation.done:
252+
logger.debug("Auth credential obtained after polling.")
253+
response, _ = self._unpack_operation(operation)
254+
return _construct_auth_credential(response)
255+
except Exception as e:
256+
raise RuntimeError(
257+
f"Failed to retrieve credential for user '{user_id}' on connector"
258+
f" '{auth_scheme.name}'."
259+
) from e
260+
261+
if metadata is not None and metadata.uri_consent_required:
262+
if self._is_consent_completed(context):
263+
raise RuntimeError("Failed to retrieve consent based credential.")
264+
265+
# Return AuthCredential with only auth_uri to trigger user consent flow.
266+
return AuthCredential(
267+
auth_type=AuthCredentialTypes.OAUTH2,
268+
oauth2=OAuth2Auth(
269+
auth_uri=metadata.uri_consent_required.authorization_uri,
270+
nonce=metadata.uri_consent_required.consent_nonce,
271+
),
272+
)

0 commit comments

Comments
 (0)