Skip to content

Commit d8d82c4

Browse files
fix(auth): isolate request state and harden token parsing
1 parent 1c6ba4d commit d8d82c4

9 files changed

Lines changed: 193 additions & 45 deletions

File tree

src/fastapi_paseto/_internal/auth.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,22 @@ def __init__(
5757
) -> None:
5858
"""Capture request-scoped objects used by token extraction helpers."""
5959

60+
self._token: str | None = None
61+
self._token_parts: list[str] = []
62+
self._current_user: str | int | None = None
63+
self._decoded_token: Token | None = None
6064
self._request_json = request_json
6165
self._response = response
62-
if request is not None:
63-
self._request = request
64-
if websocket is not None:
65-
self._websocket = websocket
66+
self._request = request
67+
self._websocket = websocket
68+
69+
def _reset_runtime_state(self) -> None:
70+
"""Clear request-scoped authentication state."""
71+
72+
self._token = None
73+
self._token_parts = []
74+
self._current_user = None
75+
self._decoded_token = None
6676

6777
def _get_paseto_from_json(
6878
self,
@@ -106,8 +116,10 @@ def _get_paseto_from_query(
106116
def _get_connection(self) -> Request | WebSocket:
107117
"""Return the current request or websocket connection."""
108118

109-
if hasattr(self, "_websocket"):
119+
if self._websocket is not None:
110120
return self._websocket
121+
if self._request is None: # pragma: no cover
122+
raise RuntimeError("Request or websocket connection is required")
111123
return self._request
112124

113125
def _get_connection_headers(self) -> Mapping[str, str]:
@@ -123,7 +135,7 @@ def _get_connection_query_params(self) -> Mapping[str, str]:
123135
def _is_websocket_connection(self) -> bool:
124136
"""Return whether the current dependency context is a websocket."""
125137

126-
return hasattr(self, "_websocket")
138+
return self._websocket is not None
127139

128140
def _get_paseto_identifier(self) -> str:
129141
"""Return a new unique token identifier."""
@@ -393,6 +405,7 @@ def paseto_required(
393405
) -> None:
394406
"""Validate the current request token against the endpoint requirements."""
395407

408+
self._reset_runtime_state()
396409
try:
397410
validate_required_token_flags(fresh=fresh, refresh_token=refresh_token)
398411
if token:

src/fastapi_paseto/_internal/request.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
from fastapi_paseto.exceptions import InvalidHeaderError
88

99

10+
def _build_token_error_message(token_key: str, token_type: str | None) -> str:
11+
"""Return the error message for a malformed token-bearing field."""
12+
13+
if token_type:
14+
return f"Bad {token_key} header. Expected value '{token_type} <PASETO>'"
15+
return f"Bad {token_key} header. Expected value '<PASETO>'"
16+
17+
1018
async def get_request_json(
1119
request: Request = None,
1220
websocket: WebSocket = None,
@@ -35,21 +43,27 @@ def extract_token_from_json(
3543
return None
3644

3745
token = request_json[json_key]
46+
if not isinstance(token, str):
47+
raise InvalidHeaderError(
48+
status_code=422,
49+
message=_build_token_error_message(json_key, json_type),
50+
)
51+
3852
if not json_type:
3953
if not token:
4054
raise InvalidHeaderError(
4155
status_code=422,
42-
message=f"Bad {json_key} header. Excepted value 'Bearer <PASETO>'",
56+
message=_build_token_error_message(json_key, json_type),
4357
)
4458
return token
4559

46-
token_prefix, token = token.split()
47-
if not token or token_prefix != json_type:
60+
parts = token.split()
61+
if len(parts) != 2 or parts[0] != json_type or not parts[1]:
4862
raise InvalidHeaderError(
4963
status_code=422,
50-
message=f"Bad {json_key} header. Expected value '{json_type} <PASETO>'",
64+
message=_build_token_error_message(json_key, json_type),
5165
)
52-
return token
66+
return parts[1]
5367

5468

5569
def extract_token_from_header(
@@ -67,11 +81,11 @@ def extract_token_from_header(
6781
if len(parts) != 1:
6882
raise InvalidHeaderError(
6983
status_code=422,
70-
message=f"Bad {header_name} header. Excepted value '<PASETO>'",
84+
message=f"Bad {header_name} header. Expected value '<PASETO>'",
7185
)
7286
return parts[0]
7387

74-
if not parts[0].__contains__(header_type) or len(parts) != 2:
88+
if len(parts) != 2 or parts[0].lower() != header_type.lower() or not parts[1]:
7589
raise InvalidHeaderError(
7690
status_code=422,
7791
message=f"Bad {header_name} header. Expected value '{header_type} <PASETO>'",

src/fastapi_paseto/auth_config.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
"""Shared mutable configuration state for ``AuthPASETO``."""
1+
"""Shared configuration state for ``AuthPASETO``."""
22

33
from collections.abc import Callable, Mapping
44
from datetime import timedelta
55

66
from pydantic import ValidationError
77
from pydantic_settings import BaseSettings
8-
from pyseto import Token
98

109
from fastapi_paseto.config import LoadConfig
1110

@@ -37,13 +36,8 @@
3736

3837
class AuthConfig:
3938
"""Hold class-level configuration shared by all ``AuthPASETO`` instances."""
40-
41-
_token: str | None = None
42-
_token_parts: list[str] = []
4339
_token_location: list[str] | tuple[str, ...] = ("headers",)
4440
_websocket_token_location: list[str] | tuple[str, ...] = ("headers",)
45-
_current_user: str | int | None = None
46-
_decoded_token: Token | None = None
4741

4842
_secret_key: str | None = None
4943
_public_key: str | None = None

tests/conftest.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,8 @@
77

88

99
_AUTHPASETO_DEFAULTS: dict[str, object] = {
10-
"_token": None,
11-
"_token_parts": [],
1210
"_token_location": ("headers",),
1311
"_websocket_token_location": ("headers",),
14-
"_current_user": None,
15-
"_decoded_token": None,
1612
"_secret_key": None,
1713
"_public_key": None,
1814
"_private_key": None,

tests/test_config.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@ def protected(Authorize: AuthPASETO = Depends()):
2727

2828

2929
def test_default_config():
30-
assert AuthPASETO._token is None
30+
authorize = AuthPASETO()
31+
32+
assert authorize._token is None
3133
assert AuthPASETO._token_location == ("headers",)
3234
assert AuthPASETO._websocket_token_location == ("headers",)
33-
assert AuthPASETO._current_user is None
34-
assert AuthPASETO._decoded_token is None
35-
assert AuthPASETO._secret_key is None
35+
assert authorize._current_user is None
36+
assert authorize._decoded_token is None
37+
assert AuthPASETO._secret_key is None
3638
assert AuthPASETO._public_key is None
3739
assert AuthPASETO._private_key is None
3840
assert AuthPASETO._purpose == "local"

tests/test_headers.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ def test_header_without_paseto(client):
4040
}
4141

4242

43-
def test_header_without_bearer(client):
44-
response = client.get("/protected", headers={"Authorization": "Test asd"})
45-
assert response.status_code == 422
46-
assert response.json() == {
47-
"detail": "Bad Authorization header. Expected value 'Bearer <PASETO>'"
43+
def test_header_without_bearer(client):
44+
response = client.get("/protected", headers={"Authorization": "Test asd"})
45+
assert response.status_code == 422
46+
assert response.json() == {
47+
"detail": "Bad Authorization header. Expected value 'Bearer <PASETO>'"
4848
}
4949

5050
response = client.get("/protected", headers={"Authorization": "Test "})
@@ -56,10 +56,30 @@ def test_header_without_bearer(client):
5656

5757
def test_header_invalid_paseto(client):
5858
response = client.get("/protected", headers={"Authorization": "Bearer asd"})
59-
assert response.status_code == 422
60-
assert response.json() == {"detail": "Invalid PASETO format"}
61-
62-
59+
assert response.status_code == 422
60+
assert response.json() == {"detail": "Invalid PASETO format"}
61+
62+
63+
def test_header_requires_exact_bearer_scheme(client, Authorize):
64+
token = Authorize.create_access_token(subject="test")
65+
66+
invalid_scheme_response = client.get(
67+
"/protected",
68+
headers={"Authorization": f"NotBearer {token}"},
69+
)
70+
assert invalid_scheme_response.status_code == 422
71+
assert invalid_scheme_response.json() == {
72+
"detail": "Bad Authorization header. Expected value 'Bearer <PASETO>'"
73+
}
74+
75+
lowercase_scheme_response = client.get(
76+
"/protected",
77+
headers={"Authorization": f"bearer {token}"},
78+
)
79+
assert lowercase_scheme_response.status_code == 200
80+
assert lowercase_scheme_response.json() == {"hello": "world"}
81+
82+
6383
def test_valid_header(client, Authorize):
6484
token = Authorize.create_access_token(subject="test")
6585
response = client.get("/protected", headers={"Authorization": f"Bearer {token}"})

tests/test_json.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,35 @@ def get_settings_one():
6060
assert response.json() == {"detail": "Invalid PASETO format"}
6161

6262

63+
@pytest.mark.parametrize(
64+
("payload", "detail"),
65+
[
66+
(
67+
{"access_token": 123},
68+
"Bad access_token header. Expected value 'Bearer <PASETO>'",
69+
),
70+
(
71+
{"access_token": ""},
72+
"Bad access_token header. Expected value 'Bearer <PASETO>'",
73+
),
74+
(
75+
{"access_token": "Bearer"},
76+
"Bad access_token header. Expected value 'Bearer <PASETO>'",
77+
),
78+
(
79+
{"access_token": "Bearer token extra"},
80+
"Bad access_token header. Expected value 'Bearer <PASETO>'",
81+
),
82+
],
83+
)
84+
def test_malformed_json_token_returns_auth_error(client, configure_auth, payload, detail):
85+
configure_auth(authpaseto_json_type="Bearer")
86+
87+
response = client.request("GET", "/protected_json", json=payload)
88+
assert response.status_code == 422
89+
assert response.json() == {"detail": detail}
90+
91+
6392
def test_wrong_location(client, access_token):
6493
class SettingsOne(BaseSettings):
6594
authpaseto_token_location: list[str] = ["headers"]

tests/test_url_protected.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,52 @@ def test_paseto_required(client, Authorize):
7979
assert response.json() == {"hello": "world"}
8080

8181

82-
def test_paseto_optional(client, Authorize):
83-
url = "/paseto-optional"
84-
# if header not define return anonym user
85-
response = client.get(url)
86-
assert response.status_code == 200
82+
def test_paseto_optional(client, Authorize):
83+
url = "/paseto-optional"
84+
# if header not define return anonym user
85+
response = client.get(url)
86+
assert response.status_code == 200
8787
assert response.json() == {"hello": "anonym"}
8888

8989
token = Authorize.create_access_token(subject="test")
90-
response = client.get(url, headers={"Authorization": f"Bearer {token}"})
91-
assert response.status_code == 200
92-
assert response.json() == {"hello": "world"}
90+
response = client.get(url, headers={"Authorization": f"Bearer {token}"})
91+
assert response.status_code == 200
92+
assert response.json() == {"hello": "world"}
93+
94+
95+
def test_paseto_optional_does_not_leak_previous_subject(client, Authorize):
96+
url = "/paseto-optional"
97+
token = Authorize.create_access_token(subject="test")
98+
99+
authorized_response = client.get(
100+
url,
101+
headers={"Authorization": f"Bearer {token}"},
102+
)
103+
assert authorized_response.status_code == 200
104+
assert authorized_response.json() == {"hello": "world"}
105+
106+
anonymous_response = client.get(url)
107+
assert anonymous_response.status_code == 200
108+
assert anonymous_response.json() == {"hello": "anonym"}
109+
110+
111+
def test_paseto_optional_invalid_token_does_not_leak_previous_subject(client, Authorize):
112+
url = "/paseto-optional"
113+
token = Authorize.create_access_token(subject="test")
114+
115+
authorized_response = client.get(
116+
url,
117+
headers={"Authorization": f"Bearer {token}"},
118+
)
119+
assert authorized_response.status_code == 200
120+
assert authorized_response.json() == {"hello": "world"}
121+
122+
invalid_response = client.get(
123+
url,
124+
headers={"Authorization": "Bearer invalid"},
125+
)
126+
assert invalid_response.status_code == 200
127+
assert invalid_response.json() == {"hello": "anonym"}
93128

94129

95130
def test_refresh_required(client, Authorize):

tests/test_websocket.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,51 @@ def test_websocket_optional_without_token(make_client: Callable[..., TestClient]
203203
assert message == {"payload": None, "subject": None}
204204

205205

206+
def test_websocket_optional_does_not_leak_previous_subject(
207+
make_client: Callable[..., TestClient],
208+
Authorize: AuthPASETO,
209+
) -> None:
210+
client = make_client()
211+
token = Authorize.create_access_token(subject="test")
212+
213+
with client.websocket_connect(
214+
"/required",
215+
headers={"Authorization": f"Bearer {token}"},
216+
) as websocket:
217+
message = websocket.receive_json()
218+
219+
assert message["subject"] == "test"
220+
221+
with client.websocket_connect("/optional") as websocket:
222+
optional_message = websocket.receive_json()
223+
224+
assert optional_message == {"payload": None, "subject": None}
225+
226+
227+
def test_websocket_optional_invalid_token_does_not_leak_previous_subject(
228+
make_client: Callable[..., TestClient],
229+
Authorize: AuthPASETO,
230+
) -> None:
231+
client = make_client()
232+
token = Authorize.create_access_token(subject="test")
233+
234+
with client.websocket_connect(
235+
"/required",
236+
headers={"Authorization": f"Bearer {token}"},
237+
) as websocket:
238+
message = websocket.receive_json()
239+
240+
assert message["subject"] == "test"
241+
242+
with client.websocket_connect(
243+
"/optional",
244+
headers={"Authorization": "Bearer invalid"},
245+
) as websocket:
246+
optional_message = websocket.receive_json()
247+
248+
assert optional_message == {"payload": None, "subject": None}
249+
250+
206251
def test_websocket_refresh_required(
207252
make_client: Callable[..., TestClient],
208253
Authorize: AuthPASETO,

0 commit comments

Comments
 (0)