|
1 | 1 | import json |
2 | | -import logging |
3 | | -import time |
4 | | -from threading import Lock |
5 | 2 | from urllib.request import urlopen |
6 | 3 |
|
7 | 4 | from authlib.oauth2.rfc7523 import JWTBearerTokenValidator |
8 | 5 | from authlib.jose.rfc7517.jwk import JsonWebKey |
9 | 6 |
|
10 | | -logger = logging.getLogger(__name__) |
11 | | - |
12 | | -JWKS_FETCH_TIMEOUT = 10 # seconds |
13 | | -# Minimum wait between back-to-back lazy retries after a failure. |
14 | | -# Keeps us from hammering Auth0 when it is actively degraded. |
15 | | -JWKS_RETRY_INTERVAL_SECONDS = 30 |
16 | | - |
17 | | - |
18 | | -# Module-level cache of successful JWKS fetches, keyed by issuer. Only |
19 | | -# successes are cached so that a transient failure is retried on the |
20 | | -# next authenticated request (``lru_cache`` would have memoised the |
21 | | -# ``None`` return, making the "lazy retry" dead code). |
22 | | -_jwks_cache: dict = {} |
23 | | -# Records the monotonic timestamp of the most recent *failed* fetch |
24 | | -# per-issuer so we can rate-limit retries without caching the failure |
25 | | -# itself. |
26 | | -_jwks_last_failure: dict = {} |
27 | | -_jwks_lock = Lock() |
28 | | - |
29 | | - |
30 | | -def _fetch_jwks_uncached(issuer: str): |
31 | | - """Fetch the JWKS for an Auth0 issuer, bypassing the cache. |
32 | | -
|
33 | | - Returns an authlib key set on success, ``None`` on failure. Errors |
34 | | - are logged rather than raised so that a transient Auth0 outage |
35 | | - doesn't crash the process at import time. |
36 | | - """ |
37 | | - jwks_url = f"{issuer}.well-known/jwks.json" |
38 | | - try: |
39 | | - with urlopen(jwks_url, timeout=JWKS_FETCH_TIMEOUT) as response: |
40 | | - return JsonWebKey.import_key_set(json.loads(response.read())) |
41 | | - except Exception as e: |
42 | | - logger.warning(f"Failed to fetch JWKS from {jwks_url}: {e}") |
43 | | - return None |
44 | | - |
45 | | - |
46 | | -def _fetch_jwks(issuer: str): |
47 | | - """Fetch JWKS, caching only successful results. |
48 | | -
|
49 | | - On failure we record the time but do not memoise the ``None`` — a |
50 | | - later call will retry (subject to ``JWKS_RETRY_INTERVAL_SECONDS`` |
51 | | - backoff) so that the validator self-heals once Auth0 recovers. |
52 | | - """ |
53 | | - with _jwks_lock: |
54 | | - cached = _jwks_cache.get(issuer) |
55 | | - if cached is not None: |
56 | | - return cached |
57 | | - last_failure = _jwks_last_failure.get(issuer) |
58 | | - if ( |
59 | | - last_failure is not None |
60 | | - and time.monotonic() - last_failure < JWKS_RETRY_INTERVAL_SECONDS |
61 | | - ): |
62 | | - # Too soon after the last failure — don't hammer Auth0. |
63 | | - return None |
64 | | - |
65 | | - # Fetch outside the lock so a slow network call doesn't block |
66 | | - # other threads that might be serving requests with a cached key. |
67 | | - key_set = _fetch_jwks_uncached(issuer) |
68 | | - |
69 | | - with _jwks_lock: |
70 | | - if key_set is not None: |
71 | | - _jwks_cache[issuer] = key_set |
72 | | - _jwks_last_failure.pop(issuer, None) |
73 | | - else: |
74 | | - _jwks_last_failure[issuer] = time.monotonic() |
75 | | - return key_set |
76 | | - |
77 | | - |
78 | | -def _clear_jwks_cache(): |
79 | | - """Test helper: wipe the success/failure caches.""" |
80 | | - with _jwks_lock: |
81 | | - _jwks_cache.clear() |
82 | | - _jwks_last_failure.clear() |
83 | | - |
84 | 7 |
|
85 | 8 | class Auth0JWTBearerTokenValidator(JWTBearerTokenValidator): |
86 | 9 | def __init__(self, domain, audience): |
87 | 10 | issuer = f"https://{domain}/" |
88 | | - |
89 | | - public_key = _fetch_jwks(issuer) |
90 | | - if public_key is None: |
91 | | - # Retry on next token validation rather than failing hard |
92 | | - # at construction time. A missing key set means token |
93 | | - # validation will fail cleanly inside authlib. |
94 | | - logger.warning( |
95 | | - "JWKS unavailable at construction; will retry on first " |
96 | | - "token validation." |
97 | | - ) |
98 | | - |
| 11 | + jsonurl = urlopen(f"{issuer}.well-known/jwks.json") |
| 12 | + public_key = JsonWebKey.import_key_set(json.loads(jsonurl.read())) |
99 | 13 | super(Auth0JWTBearerTokenValidator, self).__init__(public_key) |
100 | | - self._issuer = issuer |
101 | 14 | self.claims_options = { |
102 | 15 | "exp": {"essential": True}, |
103 | 16 | "aud": {"essential": True, "value": audience}, |
104 | 17 | "iss": {"essential": True, "value": issuer}, |
105 | 18 | } |
106 | | - |
107 | | - def authenticate_token(self, token_string): |
108 | | - # Lazy-refresh the JWKS if the initial fetch failed. Because |
109 | | - # ``_fetch_jwks`` only caches successes, this call will retry |
110 | | - # the network fetch (subject to a short backoff) until Auth0 |
111 | | - # responds. |
112 | | - if self.public_key is None: |
113 | | - self.public_key = _fetch_jwks(self._issuer) |
114 | | - return super().authenticate_token(token_string) |
0 commit comments