Skip to content

Commit e7316dc

Browse files
wukathcopybara-github
authored andcommitted
feat: Support OAuth PKCE in McpToolset
Generates a random code verifier and code challenge during the initial auth request. During token exchange, the client sends the original code verifier so the server can verify it matches the previously sent challenge. This prevents attacks by ensuring that only the client that initiated the request can obtain the final access token. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 912703679
1 parent 88421f8 commit e7316dc

6 files changed

Lines changed: 165 additions & 32 deletions

File tree

src/google/adk/auth/auth_credential.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from typing import Dict
2020
from typing import List
2121
from typing import Literal
22-
from typing import Optional
2322

2423
from pydantic import alias_generators
2524
from pydantic import BaseModel
@@ -40,9 +39,9 @@ class BaseModelWithConfig(BaseModel):
4039
class HttpCredentials(BaseModelWithConfig):
4140
"""Represents the secret token value for HTTP authentication, like user name, password, oauth token, etc."""
4241

43-
username: Optional[str] = None
44-
password: Optional[str] = None
45-
token: Optional[str] = None
42+
username: str | None = None
43+
password: str | None = None
44+
token: str | None = None
4645

4746
@classmethod
4847
def model_validate(cls, data: Dict[str, Any]) -> "HttpCredentials":
@@ -62,40 +61,43 @@ class HttpAuth(BaseModelWithConfig):
6261
# Examples: 'basic', 'bearer'
6362
scheme: str
6463
credentials: HttpCredentials
65-
additional_headers: Optional[Dict[str, str]] = None
64+
additional_headers: Dict[str, str] | None = None
6665

6766

6867
class OAuth2Auth(BaseModelWithConfig):
6968
"""Represents credential value and its metadata for a OAuth2 credential."""
7069

71-
client_id: Optional[str] = None
72-
client_secret: Optional[str] = None
70+
client_id: str | None = None
71+
client_secret: str | None = None
7372
# tool or adk can generate the auth_uri with the state info thus client
7473
# can verify the state
75-
auth_uri: Optional[str] = None
74+
auth_uri: str | None = None
7675
# A unique value generated at the start of the OAuth flow to bind the user's
7776
# session to the authorization request. This value is typically stored with
7877
# user session and passed to backend for validation.
79-
nonce: Optional[str] = None
80-
state: Optional[str] = None
78+
nonce: str | None = None
79+
state: str | None = None
8180
# tool or adk can decide the redirect_uri if they don't want client to decide
82-
redirect_uri: Optional[str] = None
83-
auth_response_uri: Optional[str] = None
84-
auth_code: Optional[str] = None
85-
access_token: Optional[str] = None
86-
refresh_token: Optional[str] = None
87-
id_token: Optional[str] = None
88-
expires_at: Optional[int] = None
89-
expires_in: Optional[int] = None
90-
audience: Optional[str] = None
91-
token_endpoint_auth_method: Optional[
81+
redirect_uri: str | None = None
82+
auth_response_uri: str | None = None
83+
auth_code: str | None = None
84+
access_token: str | None = None
85+
refresh_token: str | None = None
86+
id_token: str | None = None
87+
expires_at: int | None = None
88+
expires_in: int | None = None
89+
audience: str | None = None
90+
code_verifier: str | None = None
91+
code_challenge_method: str | None = None
92+
token_endpoint_auth_method: (
9293
Literal[
9394
"client_secret_basic",
9495
"client_secret_post",
9596
"client_secret_jwt",
9697
"private_key_jwt",
9798
]
98-
] = "client_secret_basic"
99+
| None
100+
) = "client_secret_basic"
99101

100102

101103
class ServiceAccountCredential(BaseModelWithConfig):
@@ -166,11 +168,11 @@ class ServiceAccount(BaseModelWithConfig):
166168
when ``use_id_token`` is True.
167169
"""
168170

169-
service_account_credential: Optional[ServiceAccountCredential] = None
170-
scopes: Optional[List[str]] = None
171-
use_default_credential: Optional[bool] = False
172-
use_id_token: Optional[bool] = False
173-
audience: Optional[str] = None
171+
service_account_credential: ServiceAccountCredential | None = None
172+
scopes: List[str] | None = None
173+
use_default_credential: bool | None = False
174+
use_id_token: bool | None = False
175+
audience: str | None = None
174176

175177
@model_validator(mode="after")
176178
def _validate_config(self) -> ServiceAccount:
@@ -275,9 +277,9 @@ class AuthCredential(BaseModelWithConfig):
275277
auth_type: AuthCredentialTypes
276278
# Resource reference for the credential.
277279
# This will be supported in the future.
278-
resource_ref: Optional[str] = None
280+
resource_ref: str | None = None
279281

280-
api_key: Optional[str] = None
281-
http: Optional[HttpAuth] = None
282-
service_account: Optional[ServiceAccount] = None
283-
oauth2: Optional[OAuth2Auth] = None
282+
api_key: str | None = None
283+
http: HttpAuth | None = None
284+
service_account: ServiceAccount | None = None
285+
oauth2: OAuth2Auth | None = None

src/google/adk/auth/auth_handler.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ..sessions.state import State
2929

3030
try:
31+
from authlib.common.security import generate_token
3132
from authlib.integrations.requests_client import OAuth2Session
3233

3334
AUTHLIB_AVAILABLE = True
@@ -158,6 +159,8 @@ def generate_auth_uri(
158159

159160
auth_scheme = self.auth_config.auth_scheme
160161
auth_credential = self.auth_config.raw_auth_credential
162+
if not auth_credential or not auth_credential.oauth2:
163+
raise ValueError("raw_auth_credential or oauth2 is empty")
161164

162165
if isinstance(auth_scheme, OpenIdConnectWithConfig):
163166
authorization_endpoint = auth_scheme.authorization_endpoint
@@ -190,19 +193,38 @@ def generate_auth_uri(
190193
auth_credential.oauth2.client_secret,
191194
scope=" ".join(scopes),
192195
redirect_uri=auth_credential.oauth2.redirect_uri,
196+
code_challenge_method=auth_credential.oauth2.code_challenge_method,
193197
)
194198
params = {
195199
"access_type": "offline",
196200
"prompt": "consent",
197201
}
198202
if auth_credential.oauth2.audience:
199203
params["audience"] = auth_credential.oauth2.audience
204+
205+
# If using PKCE with S256, ensure a code_verifier exists.
206+
# If not provided in the credential, generate a cryptographically secure
207+
# random token of 48 characters (OAuth2 recommends 43-128 characters).
208+
code_verifier = auth_credential.oauth2.code_verifier
209+
method = auth_credential.oauth2.code_challenge_method
210+
211+
if method:
212+
if method != "S256":
213+
raise ValueError(
214+
f"Unsupported code_challenge_method: {method}. Only 'S256' is"
215+
" supported."
216+
)
217+
if not code_verifier:
218+
code_verifier = generate_token(48)
219+
200220
uri, state = client.create_authorization_url(
201-
url=authorization_endpoint, **params
221+
url=authorization_endpoint, code_verifier=code_verifier, **params
202222
)
203223

204224
exchanged_auth_credential = auth_credential.model_copy(deep=True)
205225
exchanged_auth_credential.oauth2.auth_uri = uri
206226
exchanged_auth_credential.oauth2.state = state
227+
if code_verifier:
228+
exchanged_auth_credential.oauth2.code_verifier = code_verifier
207229

208230
return exchanged_auth_credential

src/google/adk/auth/exchanger/oauth2_credential_exchanger.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,12 @@ async def _exchange_authorization_code(
193193
return ExchangeResult(auth_credential, False)
194194

195195
try:
196+
kwargs = {}
197+
# If a code_verifier is available (e.g. from PKCE), include it in the
198+
# token exchange request.
199+
if auth_credential.oauth2 and auth_credential.oauth2.code_verifier:
200+
kwargs["code_verifier"] = auth_credential.oauth2.code_verifier
201+
196202
# Authlib already injects client_id for body-based client auth flows such
197203
# as client_secret_post, so passing it here would duplicate the field.
198204
tokens = client.fetch_token(
@@ -202,6 +208,7 @@ async def _exchange_authorization_code(
202208
),
203209
code=auth_credential.oauth2.auth_code,
204210
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
211+
**kwargs,
205212
)
206213
update_credential_with_tokens(auth_credential, tokens)
207214
logger.debug("Successfully exchanged authorization code for access token")

src/google/adk/auth/oauth2_credential_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def create_oauth2_session(
9292
redirect_uri=auth_credential.oauth2.redirect_uri,
9393
state=auth_credential.oauth2.state,
9494
token_endpoint_auth_method=auth_credential.oauth2.token_endpoint_auth_method,
95+
code_challenge_method=auth_credential.oauth2.code_challenge_method,
9596
),
9697
token_endpoint,
9798
)

tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,57 @@ async def test_exchange_success(self, mock_oauth2_session):
135135
assert exchange_result.was_exchanged
136136
mock_client.fetch_token.assert_called_once()
137137

138+
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
139+
async def test_exchange_success_pkce(self, mock_oauth2_session):
140+
"""Test successful token exchange with PKCE."""
141+
# Setup mock
142+
mock_client = Mock()
143+
mock_oauth2_session.return_value = mock_client
144+
mock_tokens = OAuth2Token({
145+
"access_token": "new_access_token",
146+
"refresh_token": "new_refresh_token",
147+
"expires_at": int(time.time()) + 3600,
148+
"expires_in": 3600,
149+
})
150+
mock_client.fetch_token.return_value = mock_tokens
151+
152+
scheme = OpenIdConnectWithConfig(
153+
type_="openIdConnect",
154+
openId_connect_url=(
155+
"https://example.com/.well-known/openid_configuration"
156+
),
157+
authorization_endpoint="https://example.com/auth",
158+
token_endpoint="https://example.com/token",
159+
scopes=["openid"],
160+
)
161+
credential = AuthCredential(
162+
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
163+
oauth2=OAuth2Auth(
164+
client_id="test_client_id",
165+
client_secret="test_client_secret",
166+
auth_response_uri="https://example.com/callback?code=auth_code",
167+
auth_code="auth_code",
168+
code_verifier="mock_code_verifier",
169+
),
170+
)
171+
172+
exchanger = OAuth2CredentialExchanger()
173+
exchange_result = await exchanger.exchange(credential, scheme)
174+
175+
# Verify token exchange was successful
176+
assert exchange_result.credential.oauth2.access_token == "new_access_token"
177+
assert (
178+
exchange_result.credential.oauth2.refresh_token == "new_refresh_token"
179+
)
180+
assert exchange_result.was_exchanged
181+
mock_client.fetch_token.assert_called_once_with(
182+
"https://example.com/token",
183+
authorization_response="https://example.com/callback?code=auth_code",
184+
code="auth_code",
185+
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
186+
code_verifier="mock_code_verifier",
187+
)
188+
138189
async def test_exchange_missing_auth_scheme(self):
139190
"""Test exchange with missing auth_scheme raises ValueError."""
140191
credential = AuthCredential(

tests/unittests/auth/test_auth_handler.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,14 @@ def __init__(
5353
scope=None,
5454
redirect_uri=None,
5555
state=None,
56+
**kwargs,
5657
):
5758
self.client_id = client_id
5859
self.client_secret = client_secret
5960
self.scope = scope
6061
self.redirect_uri = redirect_uri
6162
self.state = state
63+
self.extra_kwargs = kwargs
6264

6365
def create_authorization_url(self, url, **kwargs):
6466
params = f"client_id={self.client_id}&scope={self.scope}"
@@ -271,6 +273,54 @@ def test_generate_auth_uri_openid(
271273
assert "client_id=mock_client_id" in result.oauth2.auth_uri
272274
assert result.oauth2.state == "mock_state"
273275

276+
@patch("google.adk.auth.auth_handler.OAuth2Session")
277+
def test_generate_auth_uri_pkce(
278+
self, mock_oauth2_session, oauth2_auth_scheme, oauth2_credentials
279+
):
280+
"""Test generating an auth URI with PKCE."""
281+
oauth2_credentials.oauth2.code_challenge_method = "S256"
282+
exchanged = oauth2_credentials.model_copy(deep=True)
283+
284+
config = AuthConfig(
285+
auth_scheme=oauth2_auth_scheme,
286+
raw_auth_credential=oauth2_credentials,
287+
exchanged_auth_credential=exchanged,
288+
)
289+
290+
mock_client = Mock()
291+
mock_oauth2_session.return_value = mock_client
292+
mock_client.create_authorization_url.return_value = (
293+
"https://example.com/oauth2/authorize?code_challenge=...&code_challenge_method=S256",
294+
"mock_state",
295+
)
296+
297+
handler = AuthHandler(config)
298+
result = handler.generate_auth_uri()
299+
300+
assert result.oauth2.code_verifier is not None
301+
assert len(result.oauth2.code_verifier) == 48
302+
mock_client.create_authorization_url.assert_called_once()
303+
_, kwargs = mock_client.create_authorization_url.call_args
304+
assert "code_verifier" in kwargs
305+
assert kwargs["code_verifier"] == result.oauth2.code_verifier
306+
307+
def test_generate_auth_uri_unsupported_pkce_method(
308+
self, oauth2_auth_scheme, oauth2_credentials
309+
):
310+
"""Test generating an auth URI with unsupported PKCE method."""
311+
oauth2_credentials.oauth2.code_challenge_method = "plain"
312+
exchanged = oauth2_credentials.model_copy(deep=True)
313+
314+
config = AuthConfig(
315+
auth_scheme=oauth2_auth_scheme,
316+
raw_auth_credential=oauth2_credentials,
317+
exchanged_auth_credential=exchanged,
318+
)
319+
320+
handler = AuthHandler(config)
321+
with pytest.raises(ValueError, match="Unsupported code_challenge_method"):
322+
handler.generate_auth_uri()
323+
274324

275325
class TestGenerateAuthRequest:
276326
"""Tests for the generate_auth_request method."""

0 commit comments

Comments
 (0)