|
1 | 1 | """Manage authentication flow for FastAPI endpoints with JWK based JWT auth.""" |
2 | 2 |
|
3 | 3 | import json |
| 4 | +import time |
4 | 5 | from asyncio import Lock |
5 | 6 | from collections.abc import Callable |
6 | 7 | from typing import Any |
|
17 | 18 | from fastapi import HTTPException, Request |
18 | 19 |
|
19 | 20 | from authentication.interface import AuthInterface, AuthTuple |
20 | | -from authentication.utils import extract_user_token |
| 21 | +from authentication.utils import extract_user_token, record_auth_metrics |
21 | 22 | from constants import ( |
| 23 | + AUTH_MOD_JWK_TOKEN, |
22 | 24 | DEFAULT_VIRTUAL_PATH, |
23 | 25 | ) |
24 | 26 | from log import get_logger |
25 | | -from models.api.responses.error import UnauthorizedResponse |
| 27 | +from models.api.responses.error import ServiceUnavailableResponse, UnauthorizedResponse |
26 | 28 | from models.config import JwkConfiguration |
27 | 29 |
|
28 | 30 | logger = get_logger(__name__) |
@@ -141,6 +143,93 @@ def _internal(header: dict[str, Any], _payload: dict[str, Any]) -> Key: |
141 | 143 | return _internal |
142 | 144 |
|
143 | 145 |
|
| 146 | +async def _get_jwk_set_for_auth(config: JwkConfiguration, start_time: float) -> KeySet: |
| 147 | + """Load the configured JWK set and record bounded auth failures.""" |
| 148 | + try: |
| 149 | + return await get_jwk_set(str(config.url)) |
| 150 | + except aiohttp.ClientError as exc: |
| 151 | + logger.error("Failed to fetch JWK set: %s", exc) |
| 152 | + record_auth_metrics( |
| 153 | + AUTH_MOD_JWK_TOKEN, "failure", "jwk_fetch_error", start_time |
| 154 | + ) |
| 155 | + response = ServiceUnavailableResponse( |
| 156 | + backend_name="JWK key server", |
| 157 | + cause="Unable to reach authentication key server", |
| 158 | + ) |
| 159 | + raise HTTPException(**response.model_dump()) from exc |
| 160 | + except json.JSONDecodeError as exc: |
| 161 | + logger.error("Invalid JSON in JWK set response: %s", exc) |
| 162 | + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_json", start_time) |
| 163 | + response = ServiceUnavailableResponse( |
| 164 | + backend_name="JWK key server", |
| 165 | + cause="Authentication key server returned invalid data", |
| 166 | + ) |
| 167 | + raise HTTPException(**response.model_dump()) from exc |
| 168 | + except JoseError as exc: |
| 169 | + logger.error("Invalid JWK set format: %s", exc) |
| 170 | + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_jwk", start_time) |
| 171 | + response = ServiceUnavailableResponse( |
| 172 | + backend_name="JWK key server", |
| 173 | + cause="Authentication keys are malformed", |
| 174 | + ) |
| 175 | + raise HTTPException(**response.model_dump()) from exc |
| 176 | + |
| 177 | + |
| 178 | +def _decode_jwk_claims(user_token: str, jwk_set: KeySet, start_time: float) -> Any: |
| 179 | + """Decode a JWT and record bounded auth failures.""" |
| 180 | + try: |
| 181 | + return jwt.decode(user_token, key=key_resolver_func(jwk_set)) |
| 182 | + except (KeyNotFoundError, BadSignatureError, DecodeError, JoseError) as exc: |
| 183 | + logger.warning("Token decode error: %s", exc) |
| 184 | + record_auth_metrics( |
| 185 | + AUTH_MOD_JWK_TOKEN, "failure", "token_decode_error", start_time |
| 186 | + ) |
| 187 | + if isinstance(exc, KeyNotFoundError): |
| 188 | + cause = "Token signed by unknown key" |
| 189 | + elif isinstance(exc, BadSignatureError): |
| 190 | + cause = "Invalid token signature" |
| 191 | + elif isinstance(exc, DecodeError): |
| 192 | + cause = "Token could not be decoded" |
| 193 | + else: |
| 194 | + cause = "Token format error" |
| 195 | + response = UnauthorizedResponse(cause=cause) |
| 196 | + raise HTTPException(**response.model_dump()) from exc |
| 197 | + |
| 198 | + |
| 199 | +def _validate_jwk_claims(claims: Any, start_time: float) -> None: |
| 200 | + """Validate decoded JWT claims and record bounded auth failures.""" |
| 201 | + try: |
| 202 | + claims.validate() |
| 203 | + except ExpiredTokenError as exc: |
| 204 | + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "token_expired", start_time) |
| 205 | + response = UnauthorizedResponse(cause="Token has expired") |
| 206 | + raise HTTPException(**response.model_dump()) from exc |
| 207 | + except JoseError as exc: |
| 208 | + record_auth_metrics( |
| 209 | + AUTH_MOD_JWK_TOKEN, "failure", "token_validation_error", start_time |
| 210 | + ) |
| 211 | + response = UnauthorizedResponse(cause="Token validation failed") |
| 212 | + raise HTTPException(**response.model_dump()) from exc |
| 213 | + |
| 214 | + |
| 215 | +def _get_required_claim(claims: Any, claim_name: str, start_time: float) -> str: |
| 216 | + """Return a required JWT claim and record bounded auth failures when missing.""" |
| 217 | + try: |
| 218 | + value = claims[claim_name] |
| 219 | + except KeyError as exc: |
| 220 | + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "missing_claim", start_time) |
| 221 | + response = UnauthorizedResponse(cause=f"Token missing claim: {claim_name}") |
| 222 | + raise HTTPException(**response.model_dump()) from exc |
| 223 | + if not isinstance(value, str) or not value.strip(): |
| 224 | + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_claim", start_time) |
| 225 | + response = UnauthorizedResponse(cause=f"Token has invalid claim: {claim_name}") |
| 226 | + invalid_claim_error = ValueError( |
| 227 | + f"Token claim {claim_name} must be a non-empty string" |
| 228 | + ) |
| 229 | + raise HTTPException(**response.model_dump()) from invalid_claim_error |
| 230 | + return value |
| 231 | + |
| 232 | + |
144 | 233 | class JwkTokenAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods |
145 | 234 | """JWK AuthDependency class for JWK-based JWT authentication.""" |
146 | 235 |
|
@@ -189,73 +278,40 @@ async def __call__(self, request: Request) -> AuthTuple: |
189 | 278 | extracted from the validated JWT. Only returned on successful |
190 | 279 | authentication; all error paths raise HTTPException. |
191 | 280 | """ |
| 281 | + start_time = time.monotonic() |
| 282 | + |
192 | 283 | if not request.headers.get("Authorization"): |
| 284 | + record_auth_metrics( |
| 285 | + AUTH_MOD_JWK_TOKEN, "failure", "missing_header", start_time |
| 286 | + ) |
193 | 287 | response = UnauthorizedResponse(cause="No Authorization header found") |
194 | 288 | raise HTTPException(**response.model_dump()) |
195 | 289 |
|
196 | | - user_token = extract_user_token(request.headers) |
197 | | - |
198 | | - try: |
199 | | - jwk_set = await get_jwk_set(str(self.config.url)) |
200 | | - except aiohttp.ClientError as exc: |
201 | | - logger.error("Failed to fetch JWK set: %s", exc) |
202 | | - response = UnauthorizedResponse( |
203 | | - cause="Unable to reach authentication key server" |
204 | | - ) |
205 | | - raise HTTPException(**response.model_dump()) from exc |
206 | | - except json.JSONDecodeError as exc: |
207 | | - logger.error("Invalid JSON in JWK set response: %s", exc) |
208 | | - response = UnauthorizedResponse( |
209 | | - cause="Authentication key server returned invalid data" |
210 | | - ) |
211 | | - raise HTTPException(**response.model_dump()) from exc |
212 | | - except JoseError as exc: |
213 | | - logger.error("Invalid JWK set format: %s", exc) |
214 | | - response = UnauthorizedResponse(cause="Authentication keys are malformed") |
215 | | - raise HTTPException(**response.model_dump()) from exc |
216 | | - |
217 | | - try: |
218 | | - claims = jwt.decode(user_token, key=key_resolver_func(jwk_set)) |
219 | | - except (KeyNotFoundError, BadSignatureError, DecodeError, JoseError) as exc: |
220 | | - logger.warning("Token decode error: %s", exc) |
221 | | - cause_map = { |
222 | | - KeyNotFoundError: "Token signed by unknown key", |
223 | | - BadSignatureError: "Invalid token signature", |
224 | | - DecodeError: "Token could not be decoded", |
225 | | - JoseError: "Token format error", |
226 | | - } |
227 | | - response = UnauthorizedResponse( |
228 | | - cause=cause_map.get(type(exc), "Unknown token error") |
229 | | - ) |
230 | | - raise HTTPException(**response.model_dump()) from exc |
231 | | - |
232 | 290 | try: |
233 | | - claims.validate() |
234 | | - except ExpiredTokenError as exc: |
235 | | - response = UnauthorizedResponse(cause="Token has expired") |
236 | | - raise HTTPException(**response.model_dump()) from exc |
237 | | - except JoseError as exc: |
238 | | - response = UnauthorizedResponse(cause="Token validation failed") |
239 | | - raise HTTPException(**response.model_dump()) from exc |
240 | | - |
241 | | - try: |
242 | | - user_id: str = claims[self.config.jwt_configuration.user_id_claim] |
243 | | - except KeyError as exc: |
244 | | - missing_claim = self.config.jwt_configuration.user_id_claim |
245 | | - response = UnauthorizedResponse( |
246 | | - cause=f"Token missing claim: {missing_claim}" |
| 291 | + user_token = extract_user_token(request.headers) |
| 292 | + except HTTPException: |
| 293 | + record_auth_metrics( |
| 294 | + AUTH_MOD_JWK_TOKEN, "failure", "missing_token", start_time |
247 | 295 | ) |
248 | | - raise HTTPException(**response.model_dump()) from exc |
249 | | - |
250 | | - try: |
251 | | - username: str = claims[self.config.jwt_configuration.username_claim] |
252 | | - except KeyError as exc: |
253 | | - missing_claim = self.config.jwt_configuration.username_claim |
254 | | - response = UnauthorizedResponse( |
255 | | - cause=f"Token missing claim: {missing_claim}" |
| 296 | + raise |
| 297 | + except Exception: # pylint: disable=broad-exception-caught |
| 298 | + logger.exception("Unexpected error while extracting JWK bearer token") |
| 299 | + record_auth_metrics( |
| 300 | + AUTH_MOD_JWK_TOKEN, "failure", "unexpected_error", start_time |
256 | 301 | ) |
257 | | - raise HTTPException(**response.model_dump()) from exc |
| 302 | + raise |
| 303 | + |
| 304 | + jwk_set = await _get_jwk_set_for_auth(self.config, start_time) |
| 305 | + claims = _decode_jwk_claims(user_token, jwk_set, start_time) |
| 306 | + _validate_jwk_claims(claims, start_time) |
| 307 | + user_id = _get_required_claim( |
| 308 | + claims, self.config.jwt_configuration.user_id_claim, start_time |
| 309 | + ) |
| 310 | + username = _get_required_claim( |
| 311 | + claims, self.config.jwt_configuration.username_claim, start_time |
| 312 | + ) |
258 | 313 |
|
259 | 314 | logger.info("Successfully authenticated user %s (ID: %s)", username, user_id) |
260 | 315 |
|
| 316 | + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "success", "authenticated", start_time) |
261 | 317 | return user_id, username, self.skip_userid_check, user_token |
0 commit comments