Skip to content

Commit db6e5d5

Browse files
refactor(auth): modularize paseto internals
1 parent 1900c55 commit db6e5d5

9 files changed

Lines changed: 876 additions & 765 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Internal helpers for the public ``fastapi_paseto`` package."""
Lines changed: 393 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,393 @@
1+
"""Internal ``AuthPASETO`` implementation."""
2+
3+
import base64
4+
import json
5+
from collections.abc import Sequence
6+
from datetime import datetime, timedelta
7+
8+
from fastapi import Depends, Request, Response
9+
from pyseto import Key, Paseto, Token
10+
from pyseto.exceptions import DecryptError, SignError, VerifyError
11+
12+
from fastapi_paseto.auth_config import AuthConfig
13+
from fastapi_paseto.exceptions import (
14+
FreshTokenRequired,
15+
MissingTokenError,
16+
PASETODecodeError,
17+
RevokedTokenError,
18+
)
19+
20+
from .request import (
21+
extract_token_from_header,
22+
extract_token_from_json,
23+
get_request_json,
24+
)
25+
from .token_helpers import (
26+
build_custom_claims,
27+
build_reserved_claims,
28+
decode_base64_token,
29+
generate_token_identifier,
30+
parse_token_purpose,
31+
parse_token_version,
32+
resolve_expiry_seconds,
33+
resolve_secret_key,
34+
split_token_parts,
35+
validate_required_token_flags,
36+
validate_token_creation_arguments,
37+
validate_token_type,
38+
)
39+
40+
41+
class AuthPASETO(AuthConfig):
42+
"""FastAPI dependency that creates and validates PASETO tokens."""
43+
44+
def __init__(
45+
self,
46+
request: Request = None,
47+
response: Response = None,
48+
request_json: dict[str, object] = Depends(get_request_json),
49+
) -> None:
50+
"""Capture request-scoped objects used by token extraction helpers."""
51+
52+
self._request_json = request_json
53+
self._response = response
54+
if request is not None:
55+
self._request = request
56+
57+
def _get_paseto_from_json(
58+
self,
59+
json_key: str | None = None,
60+
json_type: str | None = None,
61+
) -> str | None:
62+
"""Extract a token from the request JSON body."""
63+
64+
return extract_token_from_json(
65+
request_json=self._request_json,
66+
json_key=json_key or self._json_key,
67+
json_type=json_type or self._json_type,
68+
)
69+
70+
def _get_paseto_from_header(
71+
self,
72+
header_name: str | None = None,
73+
header_type: str | None = None,
74+
) -> str | None:
75+
"""Extract a token from the configured request header."""
76+
77+
return extract_token_from_header(
78+
header_value=self._request.headers.get(self._header_name),
79+
header_name=header_name or self._header_name,
80+
header_type=header_type or self._header_type,
81+
)
82+
83+
def _get_paseto_identifier(self) -> str:
84+
"""Return a new unique token identifier."""
85+
86+
return generate_token_identifier()
87+
88+
def _get_secret_key(self, purpose: str, process: str) -> str:
89+
"""Return the configured key for the requested cryptographic operation."""
90+
91+
return resolve_secret_key(
92+
purpose=purpose,
93+
process=process,
94+
secret_key=self._secret_key,
95+
private_key=self._private_key,
96+
public_key=self._public_key,
97+
)
98+
99+
def _get_int_from_datetime(self, value: datetime) -> int:
100+
"""Convert a datetime to whole seconds since the Unix epoch."""
101+
102+
if not isinstance(value, datetime): # pragma: no cover
103+
raise TypeError("a datetime is required")
104+
return int(value.timestamp())
105+
106+
def _create_token(
107+
self,
108+
subject: str | int,
109+
type_token: str,
110+
exp_seconds: int,
111+
fresh: bool | None = None,
112+
issuer: str | None = None,
113+
purpose: str | None = None,
114+
audience: str | Sequence[str] = "",
115+
user_claims: dict[str, object] | None = None,
116+
version: int | None = None,
117+
base64_encode: bool = False,
118+
) -> str:
119+
"""Create and return an encoded PASETO string."""
120+
121+
validate_token_creation_arguments(
122+
subject=subject,
123+
fresh=fresh,
124+
audience=audience,
125+
purpose=purpose,
126+
version=version,
127+
user_claims=user_claims,
128+
)
129+
user_claims = user_claims or {}
130+
issuer = issuer or self._encode_issuer
131+
purpose = purpose or self._purpose
132+
version = version or self._version
133+
134+
if purpose not in ("local", "public"):
135+
raise ValueError("Purpose must be local or public.")
136+
137+
claims = {
138+
**build_reserved_claims(subject),
139+
**build_custom_claims(type_token, fresh, issuer, audience),
140+
**user_claims,
141+
}
142+
secret_key = self._get_secret_key(purpose, "encode")
143+
paseto = Paseto.new(exp=exp_seconds, include_iat=True)
144+
encoding_key = Key.new(version=version, purpose=purpose, key=secret_key)
145+
token = paseto.encode(encoding_key, claims, serializer=json)
146+
if base64_encode:
147+
token = base64.b64encode(token)
148+
return token.decode("utf-8")
149+
150+
def _has_token_in_denylist_callback(self) -> bool:
151+
"""Return whether a denylist callback has been configured."""
152+
153+
return self._token_in_denylist_callback is not None
154+
155+
def _check_token_is_revoked(self, payload: dict[str, object]) -> None:
156+
"""Raise if the decoded token has been revoked via the configured callback."""
157+
158+
if not self._denylist_enabled:
159+
return
160+
if not self._has_token_in_denylist_callback():
161+
raise RuntimeError(
162+
"A token_in_denylist_callback must be provided via "
163+
"the '@AuthPASETO.token_in_denylist_loader' if "
164+
"authpaseto_denylist_enabled is 'True'"
165+
)
166+
if self._token_in_denylist_callback.__func__(payload):
167+
raise RevokedTokenError(status_code=401, message="Token has been revoked")
168+
169+
def _get_expiry_seconds(
170+
self,
171+
type_token: str,
172+
expires_time: timedelta | datetime | int | bool | None = None,
173+
) -> int:
174+
"""Resolve an expiry configuration into the seconds expected by ``pyseto``."""
175+
176+
return resolve_expiry_seconds(
177+
type_token=type_token,
178+
expires_time=expires_time,
179+
access_expires=self._access_token_expires,
180+
refresh_expires=self._refresh_token_expires,
181+
other_expires=self._other_token_expires,
182+
)
183+
184+
def create_access_token(
185+
self,
186+
subject: str | int,
187+
fresh: bool = False,
188+
purpose: str | None = None,
189+
expires_time: timedelta | datetime | int | bool | None = None,
190+
audience: str | Sequence[str] | None = None,
191+
user_claims: dict[str, object] | None = None,
192+
base64_encode: bool = False,
193+
) -> str:
194+
"""Create a new access token."""
195+
196+
return self._create_token(
197+
subject=subject,
198+
type_token="access",
199+
exp_seconds=self._get_expiry_seconds("access", expires_time),
200+
fresh=fresh,
201+
purpose=purpose,
202+
audience=audience,
203+
user_claims=user_claims,
204+
issuer=self._encode_issuer,
205+
base64_encode=base64_encode,
206+
)
207+
208+
def create_refresh_token(
209+
self,
210+
subject: str | int,
211+
purpose: str | None = None,
212+
expires_time: timedelta | datetime | int | bool | None = None,
213+
audience: str | Sequence[str] | None = None,
214+
user_claims: dict[str, object] | None = None,
215+
base64_encode: bool = False,
216+
) -> str:
217+
"""Create a new refresh token."""
218+
219+
return self._create_token(
220+
subject=subject,
221+
type_token="refresh",
222+
exp_seconds=self._get_expiry_seconds("refresh", expires_time),
223+
purpose=purpose,
224+
audience=audience,
225+
user_claims=user_claims,
226+
base64_encode=base64_encode,
227+
)
228+
229+
def create_token(
230+
self,
231+
subject: str | int,
232+
type: str,
233+
purpose: str | None = None,
234+
expires_time: timedelta | datetime | int | bool | None = None,
235+
audience: str | Sequence[str] | None = None,
236+
user_claims: dict[str, object] | None = None,
237+
base64_encode: bool = False,
238+
) -> str:
239+
"""Create a token with a caller-provided custom type."""
240+
241+
return self._create_token(
242+
subject=subject,
243+
type_token=type,
244+
exp_seconds=self._get_expiry_seconds(type, expires_time),
245+
purpose=purpose,
246+
audience=audience,
247+
user_claims=user_claims,
248+
base64_encode=base64_encode,
249+
)
250+
251+
def _get_token_version(self) -> int:
252+
"""Return the parsed version of the current token."""
253+
254+
return parse_token_version(self._get_raw_token_parts()[0])
255+
256+
def _get_token_purpose(self) -> str:
257+
"""Return the parsed purpose of the current token."""
258+
259+
return parse_token_purpose(self._get_raw_token_parts()[1])
260+
261+
def _get_raw_token_parts(self) -> list[str]:
262+
"""Return and cache the dot-separated parts of the current token."""
263+
264+
if self._token_parts:
265+
return self._token_parts
266+
self._token_parts = split_token_parts(self._token)
267+
return self._token_parts
268+
269+
def _decode_token(self, base64_encoded: bool = False) -> Token:
270+
"""Decode and validate the current token."""
271+
272+
if base64_encoded:
273+
self._token = decode_base64_token(self._token)
274+
275+
purpose = self._get_token_purpose()
276+
version = self._get_token_version()
277+
secret_key = self._get_secret_key(purpose=purpose, process="decode")
278+
decoding_key = Key.new(version=version, purpose=purpose, key=secret_key)
279+
280+
try:
281+
paseto = Paseto.new(leeway=self._decode_leeway)
282+
token = paseto.decode(
283+
keys=decoding_key,
284+
token=self._token,
285+
deserializer=json,
286+
aud=self._decode_audience,
287+
)
288+
except (DecryptError, SignError, VerifyError) as err:
289+
raise PASETODecodeError(status_code=422, message=str(err))
290+
291+
if self._decode_issuer:
292+
if "iss" not in token.payload:
293+
raise PASETODecodeError(
294+
status_code=422,
295+
message="Token is missing the 'iss' claim",
296+
)
297+
if token.payload["iss"] != self._decode_issuer:
298+
raise PASETODecodeError(
299+
status_code=422,
300+
message="Token issuer is not valid",
301+
)
302+
303+
self._check_token_is_revoked(token.payload)
304+
self._decoded_token = token
305+
if "sub" in token.payload:
306+
self._current_user = token.payload["sub"]
307+
return token
308+
309+
def get_token_payload(self) -> dict[str, object] | None:
310+
"""Return the decoded token payload for the current request."""
311+
312+
if self._decoded_token:
313+
return self._decoded_token.payload
314+
return None
315+
316+
def get_jti(self) -> str | None:
317+
"""Return the current token identifier if present."""
318+
319+
payload = self.get_token_payload()
320+
if payload and "jti" in payload:
321+
return payload["jti"]
322+
return None
323+
324+
def get_paseto_subject(self) -> str | int | None:
325+
"""Return the current decoded token subject if present."""
326+
327+
payload = self.get_token_payload()
328+
if payload and "sub" in payload:
329+
return payload["sub"]
330+
return None
331+
332+
def get_subject(self) -> str | int | None:
333+
"""Return the cached subject captured during token validation."""
334+
335+
return self._current_user
336+
337+
def paseto_required(
338+
self,
339+
optional: bool = False,
340+
fresh: bool = False,
341+
refresh_token: bool = False,
342+
type: str | None = None,
343+
base64_encoded: bool = False,
344+
location: str | Sequence[str] | None = None,
345+
token_key: str | None = None,
346+
token_prefix: str | None = None,
347+
token: str | bool | None = None,
348+
) -> None:
349+
"""Validate the current request token against the endpoint requirements."""
350+
351+
validate_required_token_flags(fresh=fresh, refresh_token=refresh_token)
352+
if token:
353+
self._token = token
354+
else:
355+
location = location or self._token_location
356+
match True:
357+
case _ if "headers" in location:
358+
self._token = self._get_paseto_from_header(
359+
header_name=token_key or self._header_name,
360+
header_type=token_prefix or self._header_type,
361+
)
362+
case _ if "json" in location:
363+
self._token = self._get_paseto_from_json(
364+
json_key=token_key or self._json_key,
365+
json_type=token_prefix or self._json_type,
366+
)
367+
368+
if not self._token:
369+
if optional:
370+
return None
371+
raise MissingTokenError(
372+
status_code=401,
373+
message="PASETO Authorization Token required",
374+
)
375+
376+
try:
377+
self._decode_token(base64_encoded=base64_encoded)
378+
except PASETODecodeError as err:
379+
if optional:
380+
return None
381+
raise err
382+
383+
payload = self.get_token_payload()
384+
validate_token_type(
385+
payload_type=payload["type"],
386+
refresh_token=refresh_token,
387+
token_type=type,
388+
)
389+
if fresh and not payload["fresh"]:
390+
raise FreshTokenRequired(
391+
status_code=401,
392+
message="PASETO access token is not fresh",
393+
)

0 commit comments

Comments
 (0)