|
14 | 14 |
|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
| 17 | +from collections.abc import Sequence |
17 | 18 | import logging |
| 19 | +import threading |
18 | 20 | from typing import Optional |
19 | 21 |
|
20 | 22 | from fastapi.openapi.models import OAuth2 |
|
25 | 27 | from .auth_credential import AuthCredential |
26 | 28 | from .auth_credential import AuthCredentialTypes |
27 | 29 | from .auth_provider_registry import AuthProviderRegistry |
| 30 | +from .auth_schemes import AuthScheme |
28 | 31 | from .auth_schemes import AuthSchemeType |
| 32 | +from .auth_schemes import CustomAuthScheme |
29 | 33 | from .auth_schemes import ExtendedOAuth2 |
30 | 34 | from .auth_schemes import OpenIdConnectWithConfig |
31 | 35 | from .auth_tool import AuthConfig |
| 36 | +from .base_auth_provider import BaseAuthProvider |
32 | 37 | from .exchanger.base_credential_exchanger import BaseCredentialExchanger |
33 | 38 | from .exchanger.base_credential_exchanger import ExchangeResult |
34 | 39 | from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry |
|
38 | 43 | logger = logging.getLogger("google_adk." + __name__) |
39 | 44 |
|
40 | 45 |
|
| 46 | +def _rehydrate_custom_scheme( |
| 47 | + scheme: CustomAuthScheme, supported_schemes: Sequence[type[AuthScheme]] |
| 48 | +) -> CustomAuthScheme: |
| 49 | + """Rehydrate a CustomAuthScheme into one of the given supported_schemes.""" |
| 50 | + incoming_type = scheme.type_ |
| 51 | + for scheme_class in supported_schemes: |
| 52 | + type_field = scheme_class.model_fields.get("type_") |
| 53 | + # Custom AuthScheme classes must define a `default` for their `type_` field |
| 54 | + # to be rehydrated correctly. |
| 55 | + if type_field and type_field.default == incoming_type: |
| 56 | + data = scheme.model_dump(by_alias=True) |
| 57 | + if scheme.model_extra: |
| 58 | + data.update(scheme.model_extra) |
| 59 | + return scheme_class.model_validate(data) |
| 60 | + raise ValueError( |
| 61 | + f"Cannot rehydrate: no registered scheme matches type '{incoming_type}'" |
| 62 | + ) |
| 63 | + |
| 64 | + |
41 | 65 | @experimental |
42 | 66 | class CredentialManager: |
43 | 67 | """Manages authentication credentials through a structured workflow. |
@@ -77,12 +101,32 @@ class CredentialManager: |
77 | 101 | ``` |
78 | 102 | """ |
79 | 103 |
|
| 104 | + _auth_provider_registry = AuthProviderRegistry() |
| 105 | + _registry_lock = threading.Lock() |
| 106 | + |
| 107 | + @classmethod |
| 108 | + def register_auth_provider(cls, provider: BaseAuthProvider) -> None: |
| 109 | + """Public API for developers to register custom auth providers.""" |
| 110 | + with cls._registry_lock: |
| 111 | + for scheme_type in provider.supported_auth_schemes: |
| 112 | + existing_provider = cls._auth_provider_registry.get_provider( |
| 113 | + scheme_type |
| 114 | + ) |
| 115 | + if existing_provider is not None: |
| 116 | + if existing_provider is not provider: |
| 117 | + logger.warning( |
| 118 | + "An auth provider is already registered for scheme %s. " |
| 119 | + "Ignoring the new provider.", |
| 120 | + scheme_type, |
| 121 | + ) |
| 122 | + continue |
| 123 | + cls._auth_provider_registry.register(scheme_type, provider) |
| 124 | + |
80 | 125 | def __init__( |
81 | 126 | self, |
82 | 127 | auth_config: AuthConfig, |
83 | 128 | ): |
84 | 129 | self._auth_config = auth_config |
85 | | - self._auth_provider_registry = AuthProviderRegistry() |
86 | 130 | self._exchanger_registry = CredentialExchangerRegistry() |
87 | 131 | self._refresher_registry = CredentialRefresherRegistry() |
88 | 132 | self._discovery_manager = OAuth2DiscoveryManager() |
@@ -139,6 +183,20 @@ async def get_auth_credential( |
139 | 183 | ) -> Optional[AuthCredential]: |
140 | 184 | """Load and prepare authentication credential through a structured workflow.""" |
141 | 185 |
|
| 186 | + # Pydantic may have deserialized an unknown scheme into a generic |
| 187 | + # CustomAuthScheme. If so, rehydrate it first into a specific subclass. |
| 188 | + # Note: Custom authentication scheme classes must have been imported into |
| 189 | + # the Python runtime before get_auth_credential is called for their |
| 190 | + # subclasses to be registered. This is fine as developer will anyway import |
| 191 | + # them while registering the auth providers. |
| 192 | + # Note: `__subclasses__()` only returns immediate subclasses, if there is a |
| 193 | + # subclass of a subclass of CustomAuthScheme then it will not be returned. |
| 194 | + # pylint: disable=unidiomatic-typecheck Needs exact class matching. |
| 195 | + if type(self._auth_config.auth_scheme) is CustomAuthScheme: |
| 196 | + self._auth_config.auth_scheme = _rehydrate_custom_scheme( |
| 197 | + self._auth_config.auth_scheme, |
| 198 | + CustomAuthScheme.__subclasses__(), |
| 199 | + ) |
142 | 200 | # First, check if a registered auth provider is available before attempting |
143 | 201 | # to retrieve tokens natively. |
144 | 202 | provider = self._auth_provider_registry.get_provider( |
|
0 commit comments