Skip to content

Commit a220910

Browse files
google-genai-botcopybara-github
authored andcommitted
feat(auth): Add public api to register custom auth provider with credential manager
PiperOrigin-RevId: 895904372
1 parent bbad9ec commit a220910

File tree

8 files changed

+286
-31
lines changed

8 files changed

+286
-31
lines changed

src/google/adk/auth/auth_provider_registry.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,18 @@ def register(
4242
"""
4343
self._providers[auth_scheme_type] = provider_instance
4444

45-
def get_provider(self, auth_scheme: AuthScheme) -> BaseAuthProvider | None:
45+
def get_provider(
46+
self, auth_scheme: AuthScheme | type[AuthScheme]
47+
) -> BaseAuthProvider | None:
4648
"""Get the provider instance for an auth scheme.
4749
4850
Args:
49-
auth_scheme: The auth scheme to get provider for.
51+
auth_scheme: The auth scheme or the auth scheme type to get the provider
52+
for.
5053
5154
Returns:
5255
The provider instance if registered, None otherwise.
5356
"""
57+
if isinstance(auth_scheme, type):
58+
return self._providers.get(auth_scheme)
5459
return self._providers.get(type(auth_scheme))

src/google/adk/auth/auth_schemes.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from pydantic import Field
2828

2929
from ..utils.feature_decorator import experimental
30+
from .auth_credential import BaseModelWithConfig
3031

3132

3233
class OpenIdConnectWithConfig(SecurityBase):
@@ -42,8 +43,20 @@ class OpenIdConnectWithConfig(SecurityBase):
4243
scopes: Optional[List[str]] = None
4344

4445

45-
# AuthSchemes contains SecuritySchemes from OpenAPI 3.0 and an extra flattened OpenIdConnectWithConfig.
46-
AuthScheme = Union[SecurityScheme, OpenIdConnectWithConfig]
46+
class CustomAuthScheme(BaseModelWithConfig):
47+
"""A flexible model for custom authentication schemes.
48+
49+
The subclasses must define a `default` for the `type_` field, if using OAuth2
50+
user consent flow, to ensure correct rehydration.
51+
"""
52+
53+
type_: str = Field(alias="type")
54+
55+
56+
# AuthSchemes contains SecuritySchemes from OpenAPI 3.0, an extra flattened
57+
# OpenIdConnectWithConfig, and supports external schemes that subclasses
58+
# CustomAuthScheme
59+
AuthScheme = Union[SecurityScheme, OpenIdConnectWithConfig, CustomAuthScheme]
4760

4861

4962
class OAuthGrantType(str, Enum):

src/google/adk/auth/auth_tool.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,11 @@ def get_credential_key(self):
106106
if auth_scheme.model_extra:
107107
auth_scheme = auth_scheme.model_copy(deep=True)
108108
auth_scheme.model_extra.clear()
109+
110+
type_ = auth_scheme.type_
111+
type_name = type_.name if type_ and hasattr(type_, "name") else str(type_)
109112
scheme_name = (
110-
f"{auth_scheme.type_.name}_{_stable_model_digest(auth_scheme)}"
113+
f"{type_name}_{_stable_model_digest(auth_scheme)}"
111114
if auth_scheme
112115
else ""
113116
)

src/google/adk/auth/base_auth_provider.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616

1717
from abc import ABC
1818
from abc import abstractmethod
19+
from typing import TYPE_CHECKING
20+
21+
if TYPE_CHECKING:
22+
from .auth_schemes import AuthScheme
1923

2024
from ..agents.callback_context import CallbackContext
2125
from ..features import experimental
@@ -28,6 +32,15 @@
2832
class BaseAuthProvider(ABC):
2933
"""Abstract base class for custom authentication providers."""
3034

35+
@property
36+
def supported_auth_schemes(self) -> tuple[type[AuthScheme], ...]:
37+
"""The AuthScheme types supported by this provider.
38+
39+
Subclasses can override this to return a tuple of scheme types, enabling
40+
1-parameter registration.
41+
"""
42+
return ()
43+
3144
@abstractmethod
3245
async def get_auth_credential(
3346
self, auth_config: AuthConfig, context: CallbackContext

src/google/adk/auth/credential_manager.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
from __future__ import annotations
1616

17+
from collections.abc import Sequence
1718
import logging
19+
import threading
1820
from typing import Optional
1921

2022
from fastapi.openapi.models import OAuth2
@@ -25,10 +27,13 @@
2527
from .auth_credential import AuthCredential
2628
from .auth_credential import AuthCredentialTypes
2729
from .auth_provider_registry import AuthProviderRegistry
30+
from .auth_schemes import AuthScheme
2831
from .auth_schemes import AuthSchemeType
32+
from .auth_schemes import CustomAuthScheme
2933
from .auth_schemes import ExtendedOAuth2
3034
from .auth_schemes import OpenIdConnectWithConfig
3135
from .auth_tool import AuthConfig
36+
from .base_auth_provider import BaseAuthProvider
3237
from .exchanger.base_credential_exchanger import BaseCredentialExchanger
3338
from .exchanger.base_credential_exchanger import ExchangeResult
3439
from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry
@@ -38,6 +43,25 @@
3843
logger = logging.getLogger("google_adk." + __name__)
3944

4045

46+
def _rehydrate_custom_scheme(
47+
scheme: CustomAuthScheme, supported_schemes: Sequence[type[AuthScheme]]
48+
) -> CustomAuthScheme:
49+
"""Rehydrate a CustomAuthScheme into one of the given supported_schemes."""
50+
incoming_type = scheme.type_
51+
for scheme_class in supported_schemes:
52+
type_field = scheme_class.model_fields.get("type_")
53+
# Custom AuthScheme classes must define a `default` for their `type_` field
54+
# to be rehydrated correctly.
55+
if type_field and type_field.default == incoming_type:
56+
data = scheme.model_dump(by_alias=True)
57+
if scheme.model_extra:
58+
data.update(scheme.model_extra)
59+
return scheme_class.model_validate(data)
60+
raise ValueError(
61+
f"Cannot rehydrate: no registered scheme matches type '{incoming_type}'"
62+
)
63+
64+
4165
@experimental
4266
class CredentialManager:
4367
"""Manages authentication credentials through a structured workflow.
@@ -77,12 +101,32 @@ class CredentialManager:
77101
```
78102
"""
79103

104+
_auth_provider_registry = AuthProviderRegistry()
105+
_registry_lock = threading.Lock()
106+
107+
@classmethod
108+
def register_auth_provider(cls, provider: BaseAuthProvider) -> None:
109+
"""Public API for developers to register custom auth providers."""
110+
with cls._registry_lock:
111+
for scheme_type in provider.supported_auth_schemes:
112+
existing_provider = cls._auth_provider_registry.get_provider(
113+
scheme_type
114+
)
115+
if existing_provider is not None:
116+
if existing_provider is not provider:
117+
logger.warning(
118+
"An auth provider is already registered for scheme %s. "
119+
"Ignoring the new provider.",
120+
scheme_type,
121+
)
122+
continue
123+
cls._auth_provider_registry.register(scheme_type, provider)
124+
80125
def __init__(
81126
self,
82127
auth_config: AuthConfig,
83128
):
84129
self._auth_config = auth_config
85-
self._auth_provider_registry = AuthProviderRegistry()
86130
self._exchanger_registry = CredentialExchangerRegistry()
87131
self._refresher_registry = CredentialRefresherRegistry()
88132
self._discovery_manager = OAuth2DiscoveryManager()
@@ -139,6 +183,20 @@ async def get_auth_credential(
139183
) -> Optional[AuthCredential]:
140184
"""Load and prepare authentication credential through a structured workflow."""
141185

186+
# Pydantic may have deserialized an unknown scheme into a generic
187+
# CustomAuthScheme. If so, rehydrate it first into a specific subclass.
188+
# Note: Custom authentication scheme classes must have been imported into
189+
# the Python runtime before get_auth_credential is called for their
190+
# subclasses to be registered. This is fine as developer will anyway import
191+
# them while registering the auth providers.
192+
# Note: `__subclasses__()` only returns immediate subclasses, if there is a
193+
# subclass of a subclass of CustomAuthScheme then it will not be returned.
194+
# pylint: disable=unidiomatic-typecheck Needs exact class matching.
195+
if type(self._auth_config.auth_scheme) is CustomAuthScheme:
196+
self._auth_config.auth_scheme = _rehydrate_custom_scheme(
197+
self._auth_config.auth_scheme,
198+
CustomAuthScheme.__subclasses__(),
199+
)
142200
# First, check if a registered auth provider is available before attempting
143201
# to retrieve tokens natively.
144202
provider = self._auth_provider_registry.get_provider(

tests/unittests/auth/test_auth_config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from google.adk.auth.auth_credential import AuthCredential
2424
from google.adk.auth.auth_credential import AuthCredentialTypes
2525
from google.adk.auth.auth_credential import OAuth2Auth
26+
from google.adk.auth.auth_schemes import CustomAuthScheme
2627
from google.adk.auth.auth_tool import AuthConfig
2728
import pytest
2829

@@ -162,3 +163,16 @@ def _run_with_seed(seed: str) -> str:
162163
).strip()
163164

164165
assert _run_with_seed("0") == _run_with_seed("1")
166+
167+
168+
def test_credential_key_with_custom_auth_scheme():
169+
"""Test generating a credential key when the auth scheme is a CustomAuthScheme (type_ is a string)."""
170+
custom_scheme = CustomAuthScheme.model_validate({"type": "mock_custom_type"})
171+
172+
custom_config = AuthConfig(
173+
auth_scheme=custom_scheme,
174+
)
175+
176+
key = custom_config.credential_key
177+
assert key.startswith("adk_mock_custom_type_")
178+
assert len(key) > len("adk_mock_custom_type_")

tests/unittests/auth/test_auth_provider_registry.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,17 @@
1515
"""Unit tests for the AuthProviderRegistry."""
1616

1717
from google.adk.auth.auth_provider_registry import AuthProviderRegistry
18+
from google.adk.auth.auth_schemes import CustomAuthScheme
1819
from google.adk.auth.base_auth_provider import BaseAuthProvider
19-
from pydantic import BaseModel
20+
from pydantic import Field
2021

2122

22-
class SchemeA(BaseModel):
23-
pass
23+
class SchemeA(CustomAuthScheme):
24+
type_: str = Field(default="scheme_a")
2425

2526

26-
class SchemeB(BaseModel):
27-
pass
27+
class SchemeB(CustomAuthScheme):
28+
type_: str = Field(default="scheme_b")
2829

2930

3031
class TestAuthProviderRegistry:
@@ -42,10 +43,15 @@ def test_register_and_get_provider(self, mocker):
4243
assert registry.get_provider(SchemeA()) is provider_a
4344
assert registry.get_provider(SchemeB()) is provider_b
4445

46+
# Test getting by scheme type
47+
assert registry.get_provider(SchemeA) is provider_a
48+
assert registry.get_provider(SchemeB) is provider_b
49+
4550
def test_get_unregistered_provider_returns_none(self):
4651
"""Test that get_provider returns None for unregistered scheme types."""
4752
registry = AuthProviderRegistry()
4853
assert registry.get_provider(SchemeA()) is None
54+
assert registry.get_provider(SchemeA) is None
4955

5056
def test_register_duplicate_type_overwrites_existing(self, mocker):
5157
"""Test that registering a provider for an existing type overwrites the previous one."""
@@ -57,3 +63,4 @@ def test_register_duplicate_type_overwrites_existing(self, mocker):
5763
registry.register(SchemeA, provider_2)
5864

5965
assert registry.get_provider(SchemeA()) is provider_2
66+
assert registry.get_provider(SchemeA) is provider_2

0 commit comments

Comments
 (0)