Skip to content
This repository was archived by the owner on Jun 25, 2025. It is now read-only.

Commit 3d9f222

Browse files
Merge pull request #2 from code-specialist/auth-middleware-error-handling
proper error handling in authentication middleware
2 parents db8521f + 7a6893a commit 3d9f222

2 files changed

Lines changed: 51 additions & 17 deletions

File tree

fastapi_auth_middleware/middleware.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import Tuple
1+
from typing import Tuple, Callable, List
22

33
from fastapi import FastAPI
44
from starlette.authentication import AuthenticationBackend, AuthCredentials, AuthenticationError, BaseUser
55
from starlette.middleware.authentication import AuthenticationMiddleware
6-
from starlette.requests import HTTPConnection
6+
from starlette.requests import HTTPConnection, Request
7+
from starlette.responses import JSONResponse
78

89

910
class FastAPIUser(BaseUser):
@@ -44,7 +45,7 @@ def identity(self) -> str:
4445
class FastAPIAuthBackend(AuthenticationBackend):
4546
""" Auth Backend for FastAPI """
4647

47-
def __init__(self, verify_authorization_header: callable):
48+
def __init__(self, verify_authorization_header: Callable[[str], Tuple[List[str], BaseUser]]):
4849
""" Auth Backend constructor. Part of an AuthenticationMiddleware as backend.
4950
5051
Args:
@@ -64,19 +65,27 @@ async def authenticate(self, conn: HTTPConnection) -> Tuple[AuthCredentials, Bas
6465
if "Authorization" not in conn.headers:
6566
raise AuthenticationError("Authorization header missing")
6667

67-
authorization_header: str = conn.headers["Authorization"]
68-
scopes, user = self.verify_authorization_header(authorization_header)
68+
try:
69+
authorization_header: str = conn.headers["Authorization"]
70+
scopes, user = self.verify_authorization_header(authorization_header)
71+
except Exception as exception:
72+
raise AuthenticationError(exception) from None
6973

7074
return AuthCredentials(scopes=scopes), user
7175

7276

73-
def AuthMiddleware(app: FastAPI, verify_authorization_header: callable):
77+
def AuthMiddleware(
78+
app: FastAPI,
79+
verify_authorization_header: Callable[[str], Tuple[List[str], BaseUser]],
80+
auth_error_handler: Callable[[Request, AuthenticationError], JSONResponse] = None
81+
):
7482
""" Factory method, returning an AuthenticationMiddleware
7583
Intentionally not named with lower snake case convention as this is a factory method returning a class. Should feel like a class.
7684
7785
Args:
7886
app (FastAPI): The FastAPI instance the middleware should be applied to. The `add_middleware` function of FastAPI adds the app as first argument by default.
79-
verify_authorization_header (callable): A function handle that returns a list of scopes and a BaseUser
87+
verify_authorization_header (Callable[[str], Tuple[List[str], BaseUser]]): A function handle that returns a list of scopes and a BaseUser
88+
auth_error_handler (Callable[[Request, Exception], JSONResponse]): Optional error handler for creating responses when an exception was raised in verify_authorization_header
8089
8190
Examples:
8291
```python
@@ -89,4 +98,4 @@ def verify_authorization_header(auth_header: str) -> Tuple[List[str], FastAPIUse
8998
app.add_middleware(AuthMiddleware, verify_authorization_header=verify_authorization_header)
9099
```
91100
"""
92-
return AuthenticationMiddleware(app, backend=FastAPIAuthBackend(verify_authorization_header))
101+
return AuthenticationMiddleware(app, backend=FastAPIAuthBackend(verify_authorization_header), on_error=auth_error_handler)

tests/test_basic.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from typing import Callable
2+
13
from _pytest.fixtures import fixture
24
from fastapi import FastAPI
3-
from starlette.authentication import requires
5+
from starlette.authentication import requires, AuthenticationError
46
from starlette.requests import Request
7+
from starlette.responses import JSONResponse
58
from starlette.testclient import TestClient
69

710
from fastapi_auth_middleware import AuthMiddleware, FastAPIUser
@@ -20,10 +23,14 @@ def verify_authorization_header_basic_admin_scope(auth_header: str):
2023
return scopes, user
2124

2225

26+
def raise_exception_in_verify_authorization_header(_):
27+
raise Exception('some auth error occured')
28+
29+
2330
# Sample app with simple routes, takes a verify_authorization_header callable that is applied to the middleware
24-
def fastapi_app(verify_authorization_header: callable):
31+
def fastapi_app(verify_authorization_header: Callable, auth_error_handler: Callable = None):
2532
app = FastAPI()
26-
app.add_middleware(AuthMiddleware, verify_authorization_header=verify_authorization_header)
33+
app.add_middleware(AuthMiddleware, verify_authorization_header=verify_authorization_header, auth_error_handler=auth_error_handler)
2734

2835
@app.get("/")
2936
def home():
@@ -49,26 +56,44 @@ class TestBasicBehaviour:
4956
"""
5057

5158
@fixture
52-
def client(self):
59+
def client(self) -> TestClient:
5360
app = fastapi_app(verify_authorization_header_basic)
5461
return TestClient(app)
5562

5663
@fixture
57-
def client_with_scopes(self):
64+
def client_with_scopes(self) -> TestClient:
5865
app = fastapi_app(verify_authorization_header_basic_admin_scope)
5966
return TestClient(app)
6067

61-
def test_home_fail_no_header(self, client):
68+
def test_home_fail_no_header(self, client: TestClient):
6269
assert client.get("/").status_code == 400
6370

64-
def test_home_succeed(self, client):
71+
def test_home_succeed(self, client: TestClient):
6572
assert client.get("/", headers={"Authorization": "ey.."}).status_code == 200
6673

67-
def test_user_attributes(self, client):
74+
def test_user_attributes(self, client: TestClient):
6875
request = client.get("/user", headers={"Authorization": "ey.."})
6976
assert request.status_code == 200
7077
assert request.content == b'"True Code Specialist 1"' # b'"{user.is_authenticated} {user.display_name} {user.identity}"'
7178

72-
def test_scopes(self, client, client_with_scopes):
79+
def test_scopes(self, client: TestClient, client_with_scopes: TestClient):
7380
assert client.get("/admin-scope", headers={"Authorization": "ey.."}).status_code == 403 # Does not contain the requested scope
7481
assert client_with_scopes.get("/admin-scope", headers={"Authorization": "ey.."}).status_code == 200 # Contains the requested scope
82+
83+
def test_fail_auth_error(self):
84+
app = fastapi_app(verify_authorization_header=raise_exception_in_verify_authorization_header)
85+
client_with_auth_error = TestClient(app=app)
86+
87+
response = client_with_auth_error.get('/', headers={"Authorization": "ey.."})
88+
assert response.status_code == 400
89+
90+
def test_fail_auth_error_with_custom_handler(self):
91+
def handle_auth_error(request: Request, exception: AuthenticationError):
92+
assert isinstance(exception, AuthenticationError)
93+
return JSONResponse(content={'message': str(exception)}, status_code=401)
94+
95+
app = fastapi_app(verify_authorization_header=raise_exception_in_verify_authorization_header, auth_error_handler=handle_auth_error)
96+
client_with_auth_error = TestClient(app=app)
97+
98+
response = client_with_auth_error.get('/', headers={"Authorization": "ey.."})
99+
assert response.status_code == 401

0 commit comments

Comments
 (0)