Skip to content

Commit 609a1e5

Browse files
authored
Merge pull request #1488 from PolicyEngine/codex/reapply-thursday-rollback-changes
Reapply audit fixes and migrate Authlib 1.7 JWKS handling
2 parents 4449725 + 815f3f2 commit 609a1e5

27 files changed

Lines changed: 1005 additions & 87 deletions

.env-example

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
11
FLASK_DEBUG=1
2-
CACHE_REDIS_HOST=redis
2+
CACHE_REDIS_HOST=redis
3+
4+
# Optional: wipe the local SQLite analytics DB on startup. Only
5+
# consulted when FLASK_DEBUG=1 and analytics is enabled. Default off
6+
# so captured debug data is not lost across restarts.
7+
# RESET_ANALYTICS=1
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Migrate Auth0 JWKS validation to joserfc keys so Authlib 1.7.0 works without the temporary version cap.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Reapply selected audit security and correctness fixes that were rolled back in the 0.13.13 restoration.

config/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,12 @@ The following endpoints remain unprotected:
270270
- When enabled, all protected endpoints validate JWT tokens against Auth0's JWKS
271271
- The Auth0 domain and audience must match the configured values
272272

273+
## Analytics reset (debug only)
274+
275+
`RESET_ANALYTICS=1` (or `analytics.reset: true` in YAML) wipes the
276+
local SQLite analytics DB on startup. This is **only** consulted when
277+
`FLASK_DEBUG=1`; production never resets the analytics DB.
278+
273279
## Usage Examples
274280

275281
### Production Deployment (Current)

policyengine_household_api/api.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,10 @@
44

55
# Python imports
66
import os
7-
from pathlib import Path
87

98
# External imports
109
from flask_cors import CORS
1110
import flask
12-
from flask_sqlalchemy import SQLAlchemy
13-
from sqlalchemy.orm import DeclarativeBase
14-
from dotenv import load_dotenv
1511
from flask_limiter import Limiter
1612
from flask_limiter.util import get_remote_address
1713
from policyengine_household_api.data.analytics_setup import (
@@ -20,7 +16,6 @@
2016

2117
# Internal imports
2218
from .decorators.auth import create_auth_decorator
23-
from .constants import VERSION, REPO
2419
from policyengine_household_api.decorators.analytics import (
2520
log_analytics_if_enabled,
2621
)
@@ -40,6 +35,15 @@
4035

4136
app = application = flask.Flask(__name__)
4237

38+
# Reject absurdly large request bodies before any view runs. 10 MiB is
39+
# well above the largest legitimate household payload we have seen
40+
# (axes scans push a few hundred KiB) while still capping the memory a
41+
# single attacker can force us to allocate. Overridable via the
42+
# ``MAX_CONTENT_LENGTH`` env var (bytes).
43+
app.config["MAX_CONTENT_LENGTH"] = int(
44+
os.getenv("MAX_CONTENT_LENGTH", 10 * 1024 * 1024)
45+
)
46+
4347
CORS(app)
4448

4549
# Use in-memory storage for rate limiting
@@ -59,6 +63,7 @@
5963

6064
@app.route("/<country_id>/calculate", methods=["POST"])
6165
@require_auth_if_enabled()
66+
@limiter.limit("60 per minute")
6267
@log_analytics_if_enabled
6368
def calculate(country_id):
6469
return get_calculate(country_id)
@@ -84,8 +89,11 @@ def readiness_check():
8489
)
8590

8691

92+
# Note: `/calculate_demo` is intentionally public (documented in
93+
# config/README.md). It is guarded by a conservative rate limit rather
94+
# than JWT authentication.
8795
@app.route("/<country_id>/calculate_demo", methods=["POST"])
88-
@limiter.limit("1 per second")
96+
@limiter.limit("1 per 10 seconds")
8997
def calculate_demo(country_id):
9098
return get_calculate(country_id)
9199

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,114 @@
11
import json
2+
import logging
3+
import time
4+
from threading import Lock
25
from urllib.request import urlopen
36

47
from authlib.oauth2.rfc7523 import JWTBearerTokenValidator
5-
from authlib.jose.rfc7517.jwk import JsonWebKey
8+
from joserfc.jwk import KeySet
9+
10+
logger = logging.getLogger(__name__)
11+
12+
JWKS_FETCH_TIMEOUT = 10 # seconds
13+
# Minimum wait between back-to-back lazy retries after a failure.
14+
# Keeps us from hammering Auth0 when it is actively degraded.
15+
JWKS_RETRY_INTERVAL_SECONDS = 30
16+
17+
18+
# Module-level cache of successful JWKS fetches, keyed by issuer. Only
19+
# successes are cached so that a transient failure is retried on the
20+
# next authenticated request (``lru_cache`` would have memoised the
21+
# ``None`` return, making the "lazy retry" dead code).
22+
_jwks_cache: dict = {}
23+
# Records the monotonic timestamp of the most recent *failed* fetch
24+
# per-issuer so we can rate-limit retries without caching the failure
25+
# itself.
26+
_jwks_last_failure: dict = {}
27+
_jwks_lock = Lock()
28+
29+
30+
def _fetch_jwks_uncached(issuer: str):
31+
"""Fetch the JWKS for an Auth0 issuer, bypassing the cache.
32+
33+
Returns a joserfc key set on success, ``None`` on failure. Errors
34+
are logged rather than raised so that a transient Auth0 outage
35+
doesn't crash the process at import time.
36+
"""
37+
jwks_url = f"{issuer}.well-known/jwks.json"
38+
try:
39+
with urlopen(jwks_url, timeout=JWKS_FETCH_TIMEOUT) as response:
40+
return KeySet.import_key_set(json.loads(response.read()))
41+
except Exception as e:
42+
logger.warning(f"Failed to fetch JWKS from {jwks_url}: {e}")
43+
return None
44+
45+
46+
def _fetch_jwks(issuer: str):
47+
"""Fetch JWKS, caching only successful results.
48+
49+
On failure we record the time but do not memoise the ``None`` — a
50+
later call will retry (subject to ``JWKS_RETRY_INTERVAL_SECONDS``
51+
backoff) so that the validator self-heals once Auth0 recovers.
52+
"""
53+
with _jwks_lock:
54+
cached = _jwks_cache.get(issuer)
55+
if cached is not None:
56+
return cached
57+
last_failure = _jwks_last_failure.get(issuer)
58+
if (
59+
last_failure is not None
60+
and time.monotonic() - last_failure < JWKS_RETRY_INTERVAL_SECONDS
61+
):
62+
# Too soon after the last failure — don't hammer Auth0.
63+
return None
64+
65+
# Fetch outside the lock so a slow network call doesn't block
66+
# other threads that might be serving requests with a cached key.
67+
key_set = _fetch_jwks_uncached(issuer)
68+
69+
with _jwks_lock:
70+
if key_set is not None:
71+
_jwks_cache[issuer] = key_set
72+
_jwks_last_failure.pop(issuer, None)
73+
else:
74+
_jwks_last_failure[issuer] = time.monotonic()
75+
return key_set
76+
77+
78+
def _clear_jwks_cache():
79+
"""Test helper: wipe the success/failure caches."""
80+
with _jwks_lock:
81+
_jwks_cache.clear()
82+
_jwks_last_failure.clear()
683

784

885
class Auth0JWTBearerTokenValidator(JWTBearerTokenValidator):
986
def __init__(self, domain, audience):
1087
issuer = f"https://{domain}/"
11-
jsonurl = urlopen(f"{issuer}.well-known/jwks.json")
12-
public_key = JsonWebKey.import_key_set(json.loads(jsonurl.read()))
88+
89+
public_key = _fetch_jwks(issuer)
90+
if public_key is None:
91+
# Retry on next token validation rather than failing hard
92+
# at construction time. A missing key set means token
93+
# validation will fail cleanly inside authlib.
94+
logger.warning(
95+
"JWKS unavailable at construction; will retry on first "
96+
"token validation."
97+
)
98+
1399
super(Auth0JWTBearerTokenValidator, self).__init__(public_key)
100+
self._issuer = issuer
14101
self.claims_options = {
15102
"exp": {"essential": True},
16103
"aud": {"essential": True, "value": audience},
17104
"iss": {"essential": True, "value": issuer},
18105
}
106+
107+
def authenticate_token(self, token_string):
108+
# Lazy-refresh the JWKS if the initial fetch failed. Because
109+
# ``_fetch_jwks`` only caches successes, this call will retry
110+
# the network fetch (subject to a short backoff) until Auth0
111+
# responds.
112+
if self.public_key is None:
113+
self.public_key = _fetch_jwks(self._issuer)
114+
return super().authenticate_token(token_string)

policyengine_household_api/country.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import importlib
2+
import logging
23
from flask import Response
34
import json
45
from policyengine_core.taxbenefitsystems import TaxBenefitSystem
@@ -431,8 +432,12 @@ def calculate(
431432

432433
return household, None
433434

434-
except Exception as e:
435-
print(f"Error computing tracer output: {e}")
435+
except Exception:
436+
# Re-raise so endpoints/household.py (which unpacks
437+
# ``(result, computation_tree_uuid)``) can surface a real
438+
# 500 instead of a TypeError on ``None`` unpacking.
439+
logging.exception("Tracer failed while computing household")
440+
raise
436441

437442

438443
def create_policy_reform(policy_data: dict) -> dict:
@@ -477,7 +482,7 @@ def apply(self):
477482

478483

479484
def get_requested_computations(household: dict):
480-
requested_computations = dpath.util.search(
485+
requested_computations = dpath.search(
481486
household,
482487
"*/*/*/*",
483488
afilter=lambda t: t is None,

policyengine_household_api/data/analytics_setup.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,21 @@ def initialize_analytics_db_if_enabled(app):
4747
db_url = (
4848
REPO / "policyengine_household_api" / "data" / "policyengine.db"
4949
)
50-
if Path(db_url).exists():
50+
# Only wipe the analytics DB when explicitly requested via
51+
# RESET_ANALYTICS=1 (or the ``analytics.reset`` config flag).
52+
should_reset = os.getenv("RESET_ANALYTICS", "").lower() in (
53+
"1",
54+
"true",
55+
"yes",
56+
) or get_config_value("analytics.reset", False)
57+
if should_reset and Path(db_url).exists():
5158
Path(db_url).unlink()
5259
if not Path(db_url).exists():
5360
Path(db_url).touch()
54-
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:////" + str(db_url)
61+
# SQLite URLs for absolute paths need three slashes plus the
62+
# leading "/" of the path itself (=> "sqlite:////tmp/x.db").
63+
# db_url is already absolute, so prefixing "sqlite:///" is enough.
64+
app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{db_url}"
5565
else:
5666
app.config["SQLALCHEMY_DATABASE_URI"] = "mysql+pymysql://"
5767
app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"creator": getconn}

policyengine_household_api/decorators/analytics.py

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from functools import wraps
77
from flask import request
8-
from datetime import datetime
8+
from datetime import datetime, timezone
99
import jwt
1010
import logging
1111
from policyengine_household_api.constants import VERSION
@@ -14,10 +14,56 @@
1414
)
1515
from policyengine_household_api.data.analytics_setup import db
1616
from policyengine_household_api.data.models import Visit
17+
from policyengine_household_api.utils.config_loader import get_config_value
1718

1819
logger = logging.getLogger(__name__)
1920

2021

22+
# Cache the JWKS client so we don't re-fetch keys on every request.
23+
_jwks_client_cache: dict = {}
24+
25+
26+
def _get_jwks_client(auth0_address: str):
27+
"""Return a cached PyJWKClient for the given Auth0 domain."""
28+
if auth0_address not in _jwks_client_cache:
29+
jwks_url = f"https://{auth0_address}/.well-known/jwks.json"
30+
_jwks_client_cache[auth0_address] = jwt.PyJWKClient(jwks_url)
31+
return _jwks_client_cache[auth0_address]
32+
33+
34+
def _verified_sub_claim(token: str) -> str | None:
35+
"""
36+
Return the token's ``sub`` claim if signature verification succeeds
37+
against the configured Auth0 JWKS, else ``None``.
38+
39+
If Auth0 configuration is missing (e.g. in a dev environment) the
40+
claim cannot be trusted and we return ``None`` so the caller can
41+
store a null client_id rather than an attacker-controlled value.
42+
"""
43+
auth0_address = get_config_value("auth.auth0.address", "")
44+
auth0_audience = get_config_value("auth.auth0.audience", "")
45+
if not auth0_address or not auth0_audience:
46+
return None
47+
48+
try:
49+
signing_key = _get_jwks_client(auth0_address).get_signing_key_from_jwt(
50+
token
51+
)
52+
claims = jwt.decode(
53+
token,
54+
signing_key.key,
55+
algorithms=["RS256"],
56+
audience=auth0_audience,
57+
issuer=f"https://{auth0_address}/",
58+
options={"verify_signature": True},
59+
)
60+
except Exception as e:
61+
logger.debug(f"JWT signature verification failed: {e}")
62+
return None
63+
64+
return claims.get("sub")
65+
66+
2167
def log_analytics_if_enabled(func):
2268
"""
2369
Decorator that logs analytics only if analytics is enabled in configuration.
@@ -45,25 +91,34 @@ def decorated_function(*args, **kwargs):
4591
# Create a record that will be emitted to the db
4692
new_visit = Visit()
4793

48-
# Pull client_id from JWT
94+
# Pull client_id from JWT. We only trust the `sub` claim
95+
# when the token signature has been verified against the
96+
# Auth0 JWKS. If verification fails (bad signature, JWKS
97+
# unreachable, etc.) we still record the visit but drop
98+
# the client_id so that attackers cannot spoof analytics
99+
# identities simply by crafting an unsigned JWT.
49100
try:
50101
auth_header = str(request.authorization)
51102
token = auth_header.split(" ")[1]
52-
decoded_token = jwt.decode(
53-
token, options={"verify_signature": False}
54-
)
55-
client_id = decoded_token["sub"]
56-
57-
suffix_to_slice = "@clients"
58-
if (
59-
len(client_id) >= len(suffix_to_slice)
60-
and client_id[-len(suffix_to_slice) :] == suffix_to_slice
61-
):
62-
client_id = client_id[: -len(suffix_to_slice)]
63-
new_visit.client_id = client_id
103+
client_id = _verified_sub_claim(token)
104+
105+
if client_id is None:
106+
new_visit.client_id = None
107+
else:
108+
suffix_to_slice = "@clients"
109+
if (
110+
len(client_id) >= len(suffix_to_slice)
111+
and client_id[-len(suffix_to_slice) :]
112+
== suffix_to_slice
113+
):
114+
client_id = client_id[: -len(suffix_to_slice)]
115+
new_visit.client_id = client_id
64116
except Exception as e:
65117
logger.debug(f"Could not extract client_id from JWT: {e}")
66-
new_visit.client_id = "unknown"
118+
# Match the verification-failure path: a missing/unparseable
119+
# header must also be stored as NULL, never as a
120+
# sentinel string we'd have to filter out downstream.
121+
new_visit.client_id = None
67122

68123
# Set API version
69124
new_visit.api_version = VERSION
@@ -75,8 +130,9 @@ def decorated_function(*args, **kwargs):
75130
# Set content_length_bytes
76131
new_visit.content_length_bytes = request.content_length
77132

78-
# Set the date and time
79-
now = datetime.utcnow()
133+
# Set the date and time (timezone-aware; utcnow() is
134+
# deprecated in Python 3.12+)
135+
now = datetime.now(timezone.utc)
80136
new_visit.datetime = now
81137

82138
# Emit the new record to the db

0 commit comments

Comments
 (0)