55
66from functools import wraps
77from flask import request
8- from datetime import datetime
8+ from datetime import datetime , timezone
99import jwt
1010import logging
1111from policyengine_household_api .constants import VERSION
1414)
1515from policyengine_household_api .data .analytics_setup import db
1616from policyengine_household_api .data .models import Visit
17+ from policyengine_household_api .utils .config_loader import get_config_value
1718
1819logger = logging .getLogger (__name__ )
1920
2021
# One PyJWKClient per Auth0 domain, kept for the life of the process so the
# client object is not rebuilt on every request.
_jwks_client_cache: dict = {}


def _get_jwks_client(auth0_address: str):
    """
    Look up (or lazily create) the PyJWKClient for an Auth0 domain.

    The client is built against the domain's
    ``/.well-known/jwks.json`` endpoint the first time the domain is
    seen and is then reused from the module-level cache.
    """
    client = _jwks_client_cache.get(auth0_address)
    if client is None:
        # Cache miss: build the client once and remember it.
        jwks_url = f"https://{auth0_address}/.well-known/jwks.json"
        client = jwt.PyJWKClient(jwks_url)
        _jwks_client_cache[auth0_address] = client
    return client
32+
33+
def _verified_sub_claim(token: str) -> str | None:
    """
    Return the token's ``sub`` claim if signature verification succeeds
    against the configured Auth0 JWKS, else ``None``.

    The token must carry a valid RS256 signature and match the
    configured audience and the ``https://<auth0 address>/`` issuer.

    If Auth0 configuration is missing (e.g. in a dev environment) the
    claim cannot be trusted and we return ``None`` so the caller can
    store a null client_id rather than an attacker-controlled value.

    ``None`` is also returned when the verified token has no ``sub``
    claim or when that claim is not a string, so callers can rely on
    the annotated return type.
    """
    auth0_address = get_config_value("auth.auth0.address", "")
    auth0_audience = get_config_value("auth.auth0.audience", "")
    if not auth0_address or not auth0_audience:
        return None

    try:
        signing_key = _get_jwks_client(
            auth0_address
        ).get_signing_key_from_jwt(token)
        claims = jwt.decode(
            token,
            signing_key.key,
            algorithms=["RS256"],
            audience=auth0_audience,
            issuer=f"https://{auth0_address}/",
            options={"verify_signature": True},
        )
    except Exception as e:
        # Deliberately broad: this is best-effort analytics, so any
        # failure while resolving the key or decoding the token is
        # treated as "unverified" rather than propagated to the request.
        logger.debug("JWT signature verification failed: %s", e)
        return None

    sub = claims.get("sub")
    # Callers slice and store this value as a string; never hand back a
    # non-string claim even if the token was otherwise valid.
    return sub if isinstance(sub, str) else None
65+
66+
2167def log_analytics_if_enabled (func ):
2268 """
2369 Decorator that logs analytics only if analytics is enabled in configuration.
@@ -45,25 +91,34 @@ def decorated_function(*args, **kwargs):
4591 # Create a record that will be emitted to the db
4692 new_visit = Visit ()
4793
48- # Pull client_id from JWT
94+ # Pull client_id from JWT. We only trust the `sub` claim
95+ # when the token signature has been verified against the
96+ # Auth0 JWKS. If verification fails (bad signature, JWKS
97+ # unreachable, etc.) we still record the visit but drop
98+ # the client_id so that attackers cannot spoof analytics
99+ # identities simply by crafting an unsigned JWT.
49100 try :
50101 auth_header = str (request .authorization )
51102 token = auth_header .split (" " )[1 ]
52- decoded_token = jwt .decode (
53- token , options = {"verify_signature" : False }
54- )
55- client_id = decoded_token ["sub" ]
56-
57- suffix_to_slice = "@clients"
58- if (
59- len (client_id ) >= len (suffix_to_slice )
60- and client_id [- len (suffix_to_slice ) :] == suffix_to_slice
61- ):
62- client_id = client_id [: - len (suffix_to_slice )]
63- new_visit .client_id = client_id
103+ client_id = _verified_sub_claim (token )
104+
105+ if client_id is None :
106+ new_visit .client_id = None
107+ else :
108+ suffix_to_slice = "@clients"
109+ if (
110+ len (client_id ) >= len (suffix_to_slice )
111+ and client_id [- len (suffix_to_slice ) :]
112+ == suffix_to_slice
113+ ):
114+ client_id = client_id [: - len (suffix_to_slice )]
115+ new_visit .client_id = client_id
64116 except Exception as e :
65117 logger .debug (f"Could not extract client_id from JWT: { e } " )
66- new_visit .client_id = "unknown"
118+ # Match the verified-fail path: a missing/unparseable
119+ # header must also be stored as NULL, never as a
120+ # sentinel string we'd have to filter out downstream.
121+ new_visit .client_id = None
67122
68123 # Set API version
69124 new_visit .api_version = VERSION
@@ -75,8 +130,9 @@ def decorated_function(*args, **kwargs):
75130 # Set content_length_bytes
76131 new_visit .content_length_bytes = request .content_length
77132
78- # Set the date and time
79- now = datetime .utcnow ()
133+ # Set the date and time (timezone-aware; utcnow() is
134+ # deprecated in Python 3.12+)
135+ now = datetime .now (timezone .utc )
80136 new_visit .datetime = now
81137
82138 # Emit the new record to the db
0 commit comments