|
10 | 10 | from cryptojwt.jwt import utc_time_sans_frac |
11 | 11 |
|
12 | 12 | from idpyoidc import claims |
| 13 | +from idpyoidc.util import importer |
13 | 14 | from idpyoidc.message import Message |
14 | 15 | from idpyoidc.message import oidc |
15 | 16 | from idpyoidc.message.oauth2 import ResponseMessage |
16 | 17 | from idpyoidc.server.endpoint import Endpoint |
17 | 18 | from idpyoidc.server.exception import ClientAuthenticationError |
| 19 | +from idpyoidc.exception import ImproperlyConfigured |
18 | 20 | from idpyoidc.server.util import OAUTH2_NOCACHE_HEADERS |
19 | 21 |
|
20 | 22 | logger = logging.getLogger(__name__) |
@@ -46,18 +48,28 @@ def __init__(self, upstream_get: Callable, add_claims_by_scope: Optional[bool] = |
46 | 48 | # Add the issuer ID as an allowed JWT target |
47 | 49 | self.allowed_targets.append("") |
48 | 50 |
|
49 | | - def get_client_id_from_token(self, context, token, request=None): |
50 | | - _info = context.session_manager.get_session_info_by_token( |
| 51 | + if kwargs is None: |
| 52 | + self.config = { |
| 53 | + "policy": { |
| 54 | + "function": "/path/to/callable", |
| 55 | + "kwargs": {} |
| 56 | + }, |
| 57 | + } |
| 58 | + else: |
| 59 | + self.config = kwargs |
| 60 | + |
| 61 | + def get_client_id_from_token(self, endpoint_context, token, request=None): |
| 62 | + _info = endpoint_context.session_manager.get_session_info_by_token( |
51 | 63 | token, handler_key="access_token" |
52 | 64 | ) |
53 | 65 | return _info["client_id"] |
54 | 66 |
|
55 | 67 | def do_response( |
56 | | - self, |
57 | | - response_args: Optional[Union[Message, dict]] = None, |
58 | | - request: Optional[Union[Message, dict]] = None, |
59 | | - client_id: Optional[str] = "", |
60 | | - **kwargs |
| 68 | + self, |
| 69 | + response_args: Optional[Union[Message, dict]] = None, |
| 70 | + request: Optional[Union[Message, dict]] = None, |
| 71 | + client_id: Optional[str] = "", |
| 72 | + **kwargs |
61 | 73 | ) -> dict: |
62 | 74 |
|
63 | 75 | if "error" in kwargs and kwargs["error"]: |
@@ -157,6 +169,12 @@ def process_request(self, request=None, **kwargs): |
157 | 169 | info["sub"] = _grant.sub |
158 | 170 | if _grant.add_acr_value("userinfo"): |
159 | 171 | info["acr"] = _grant.authentication_event["authn_info"] |
| 172 | + |
| 173 | + if "userinfo" in _cntxt.cdb[request["client_id"]]: |
| 174 | + self.config["policy"] = _cntxt.cdb[request["client_id"]]["userinfo"]["policy"] |
| 175 | + |
| 176 | + if "policy" in self.config: |
| 177 | + info = self._enforce_policy(request, info, token, self.config) |
160 | 178 | else: |
161 | 179 | info = { |
162 | 180 | "error": "invalid_request", |
@@ -190,3 +208,26 @@ def parse_request(self, request, http_info=None, **kwargs): |
190 | 208 | request["access_token"] = auth_info["token"] |
191 | 209 |
|
192 | 210 | return request |
| 211 | + |
| 212 | + def _enforce_policy(self, request, response_info, token, config): |
| 213 | + policy = config["policy"] |
| 214 | + callable = policy["function"] |
| 215 | + kwargs = policy.get("kwargs", {}) |
| 216 | + |
| 217 | + if isinstance(callable, str): |
| 218 | + try: |
| 219 | + fn = importer(callable) |
| 220 | + except Exception: |
| 221 | + raise ImproperlyConfigured(f"Error importing {callable} policy callable") |
| 222 | + else: |
| 223 | + fn = callable |
| 224 | + |
| 225 | + try: |
| 226 | + return fn(request, token, response_info, **kwargs) |
| 227 | + except Exception as e: |
| 228 | + logger.error(f"Error while executing the {fn} policy callable: {e}") |
| 229 | + return self.error_cls(error="server_error", error_description="Internal server error") |
| 230 | + |
| 231 | + |
| 232 | +def validate_userinfo_policy(request, token, response_info, **kwargs): |
| 233 | + return response_info |
0 commit comments