|
7 | 7 | except ImportError: |
8 | 8 | from urlparse import parse_qs, urlparse, urlunparse |
9 | 9 | from urllib import urlencode, quote_plus |
| 10 | +import inspect |
10 | 11 | import logging |
11 | 12 | import warnings |
12 | 13 | import time |
@@ -104,6 +105,11 @@ def __init__( |
104 | 105 | or a raw JWT assertion in bytes (which we will relay to http layer). |
105 | 106 | It can also be a callable (recommended), |
106 | 107 | so that we will do lazy creation of an assertion. |
| 108 | +
|
| 109 | + The callable may accept zero arguments (legacy) or one argument. |
| 110 | + When it accepts one argument, it will receive a dict containing |
| 111 | + ``"client_id"``, ``"token_endpoint"``, and optionally ``"fmi_path"`` |
| 112 | + (when an FMI path is set on the current request). |
107 | 113 | client_assertion_type (str): |
108 | 114 | The type of your :attr:`client_assertion` parameter. |
109 | 115 | It is typically the value of :attr:`CLIENT_ASSERTION_TYPE_SAML2` or |
@@ -168,6 +174,35 @@ def __init__( |
168 | 174 | # A workaround for requests not supporting session-wide timeout |
169 | 175 | self._http_client.request, timeout=timeout) |
170 | 176 |
|
| 177 | + @staticmethod |
| 178 | + def _accepts_context(func): |
| 179 | + """Check if a callable accepts at least one positional argument.""" |
| 180 | + try: |
| 181 | + sig = inspect.signature(func) |
| 182 | + params = [ |
| 183 | + p for p in sig.parameters.values() |
| 184 | + if p.kind in ( |
| 185 | + inspect.Parameter.POSITIONAL_ONLY, |
| 186 | + inspect.Parameter.POSITIONAL_OR_KEYWORD, |
| 187 | + ) |
| 188 | + ] |
| 189 | + return len(params) >= 1 |
| 190 | + except (ValueError, TypeError): |
| 191 | + return False # Signature not inspectable; treat as zero-arg |
| 192 | + |
| 193 | + def _invoke_assertion_callable(self, assertion_callable, data=None): |
| 194 | + """Invoke an assertion callable, passing context if it accepts one.""" |
| 195 | + if self._accepts_context(assertion_callable): |
| 196 | + context = { |
| 197 | + "client_id": self.client_id, |
| 198 | + "token_endpoint": self.configuration.get( |
| 199 | + "token_endpoint", ""), |
| 200 | + } |
| 201 | + if data and data.get("fmi_path"): |
| 202 | + context["fmi_path"] = data["fmi_path"] |
| 203 | + return assertion_callable(context) |
| 204 | + return assertion_callable() |
| 205 | + |
171 | 206 | def _build_auth_request_params(self, response_type, **kwargs): |
172 | 207 | # response_type is a string defined in |
173 | 208 | # https://tools.ietf.org/html/rfc6749#section-3.1.1 |
@@ -198,11 +233,11 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 |
198 | 233 | # See https://tools.ietf.org/html/rfc7521#section-4.2 |
199 | 234 | encoder = self.client_assertion_encoders.get( |
200 | 235 | self.default_body["client_assertion_type"], lambda a: a) |
201 | | - _data["client_assertion"] = encoder( |
202 | | - self.client_assertion() # Do lazy on-the-fly computation |
203 | | - if callable(self.client_assertion) else self.client_assertion |
204 | | - ) # The type is bytes, which is preferable. See also: |
205 | | - # https://github.com/psf/requests/issues/4503#issuecomment-455001070 |
| 236 | + if callable(self.client_assertion): |
| 237 | + raw = self._invoke_assertion_callable(self.client_assertion, data) |
| 238 | + else: |
| 239 | + raw = self.client_assertion |
| 240 | + _data["client_assertion"] = encoder(raw) |
206 | 241 |
|
207 | 242 | _data.update(self.default_body) # It may contain authen parameters |
208 | 243 | _data.update(data or {}) # So the content in data param prevails |
@@ -770,6 +805,34 @@ class initialization. |
770 | 805 | data.update(scope=scope) |
771 | 806 | return self._obtain_token("client_credentials", data=data, **kwargs) |
772 | 807 |
|
| 808 | + def obtain_token_by_user_fic( |
| 809 | + self, scope, assertion, username=None, user_object_id=None, |
| 810 | + **kwargs): |
| 811 | + """Obtain token using the ``user_fic`` grant type. |
| 812 | +
|
| 813 | + This exchanges a federated identity credential (e.g. an agent |
| 814 | + instance token) for a user-scoped access token. |
| 815 | +
|
| 816 | + :param scope: Scopes for the target resource (already decorated |
| 817 | + with OIDC scopes by the caller). |
| 818 | + :param str assertion: The federated identity credential token. |
| 819 | + :param str username: The target user's UPN (mutually exclusive |
| 820 | + with *user_object_id*). |
| 821 | + :param str user_object_id: The target user's Object ID (mutually |
| 822 | + exclusive with *username*). |
| 823 | + """ |
| 824 | + data = kwargs.pop("data", {}) |
| 825 | + data.update( |
| 826 | + scope=scope, |
| 827 | + user_federated_identity_credential=assertion, |
| 828 | + client_info="1", |
| 829 | + ) |
| 830 | + if user_object_id: |
| 831 | + data["user_id"] = str(user_object_id) |
| 832 | + elif username: |
| 833 | + data["username"] = username |
| 834 | + return self._obtain_token("user_fic", data=data, **kwargs) |
| 835 | + |
773 | 836 | def __init__(self, |
774 | 837 | server_configuration, client_id, |
775 | 838 | on_obtaining_tokens=lambda event: None, # event is defined in _obtain_token(...) |
|
0 commit comments