Skip to content

Commit 3b5b8a2

Browse files
committed
refactor: Move auth extractors into authentication module
And split the auth check into two so that other methods can access the raw bearer token if required.
1 parent ab2ebc6 commit 3b5b8a2

3 files changed

Lines changed: 112 additions & 34 deletions

File tree

src/blueapi/service/authentication.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,20 @@
66
import time
77
import webbrowser
88
from abc import ABC, abstractmethod
9+
from collections.abc import Mapping
910
from functools import cached_property
1011
from http import HTTPStatus
1112
from pathlib import Path
12-
from typing import Any, cast
13+
from typing import Annotated, Any, cast
1314

1415
import httpx
1516
import jwt
1617
import requests
18+
from fastapi import Depends, HTTPException, Request
19+
from fastapi.security.utils import get_authorization_scheme_param
1720
from pydantic import TypeAdapter
1821
from requests.auth import AuthBase
22+
from starlette.status import HTTP_401_UNAUTHORIZED
1923

2024
from blueapi.config import OIDCConfig, ServiceAccount
2125
from blueapi.service.model import Cache
@@ -272,3 +276,52 @@ def get_access_token(self):
272276
def sync_auth_flow(self, request):
273277
request.headers["Authorization"] = f"Bearer {self.get_access_token()}"
274278
yield request
279+
280+
281+
def unchecked_bearer_token(req: Request) -> str | None:
282+
"""Get bearer token value from authorization header"""
283+
auth = req.headers.get("Authorization")
284+
scheme, param = get_authorization_scheme_param(auth)
285+
if scheme.casefold() != "bearer":
286+
return None
287+
return param.strip()
288+
289+
290+
UncheckedBearerToken = Annotated[str | None, Depends(unchecked_bearer_token)]
291+
292+
293+
def build_access_token_check(config: OIDCConfig):
294+
"""
295+
Create a function to validate the bearer token of requests
296+
297+
The returned function should be used via fastAPI's 'Depends' mechanism to
298+
ensure users are authenticated
299+
"""
300+
jwkclient = jwt.PyJWKClient(config.jwks_uri)
301+
302+
def validate_bearer_token(request: Request, token: UncheckedBearerToken):
303+
"""Check that a bearer token is valid and inject into request state"""
304+
if not token:
305+
raise HTTPException(
306+
status_code=HTTP_401_UNAUTHORIZED,
307+
detail="Not authenticated",
308+
headers={"WWW-Authenticate": "Bearer"},
309+
)
310+
311+
signing_key = jwkclient.get_signing_key_from_jwt(token)
312+
decoded: dict[str, Any] = jwt.decode(
313+
token,
314+
signing_key.key,
315+
algorithms=config.id_token_signing_alg_values_supported,
316+
verify=True,
317+
audience=config.client_audience,
318+
issuer=config.issuer,
319+
)
320+
request.state.decoded_access_token = decoded
321+
322+
return validate_bearer_token
323+
324+
325+
def access_token(request: Request) -> Mapping[str, Any] | None:
326+
"""Get the decoded and verified access token of the user making the request"""
327+
return getattr(request.state, "decoded_access_token", None)

src/blueapi/service/main.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from fastapi.datastructures import Address
2020
from fastapi.middleware.cors import CORSMiddleware
2121
from fastapi.responses import RedirectResponse, StreamingResponse
22-
from fastapi.security import OAuth2AuthorizationCodeBearer
2322
from observability_utils.tracing import (
2423
add_span_attributes,
2524
get_tracer,
@@ -37,6 +36,7 @@
3736
from blueapi import __version__
3837
from blueapi.config import ApplicationConfig, OIDCConfig, Tag
3938
from blueapi.service import interface
39+
from blueapi.service.authentication import build_access_token_check
4040
from blueapi.worker import TrackableTask, WorkerState
4141
from blueapi.worker.event import TaskStatusEnum
4242

@@ -61,6 +61,7 @@
6161
RUNNER: WorkerDispatcher | None = None
6262

6363
LOGGER = logging.getLogger(__name__)
64+
TRACER = get_tracer("interface")
6465

6566

6667
def _runner() -> WorkerDispatcher:
@@ -117,7 +118,7 @@ def get_app(config: ApplicationConfig):
117118
)
118119
dependencies = []
119120
if config.oidc:
120-
dependencies.append(Depends(decode_access_token(config.oidc)))
121+
dependencies.append(Depends(build_access_token_check(config.oidc)))
121122
app.swagger_ui_init_oauth = {
122123
"clientId": "NOT_SUPPORTED",
123124
}
@@ -140,32 +141,6 @@ def get_app(config: ApplicationConfig):
140141
return app
141142

142143

143-
def decode_access_token(config: OIDCConfig):
144-
jwkclient = jwt.PyJWKClient(config.jwks_uri)
145-
oauth_scheme = OAuth2AuthorizationCodeBearer(
146-
authorizationUrl=config.authorization_endpoint,
147-
tokenUrl=config.token_endpoint,
148-
refreshUrl=config.token_endpoint,
149-
)
150-
151-
def inner(request: Request, access_token: str = Depends(oauth_scheme)):
152-
signing_key = jwkclient.get_signing_key_from_jwt(access_token)
153-
decoded: dict[str, Any] = jwt.decode(
154-
access_token,
155-
signing_key.key,
156-
algorithms=config.id_token_signing_alg_values_supported,
157-
verify=True,
158-
audience=config.client_audience,
159-
issuer=config.issuer,
160-
)
161-
request.state.decoded_access_token = decoded
162-
163-
return inner
164-
165-
166-
TRACER = get_tracer("interface")
167-
168-
169144
async def on_key_error_404(_: Request, __: Exception):
170145
return JSONResponse(
171146
status_code=status.HTTP_404_NOT_FOUND,

tests/unit_tests/service/test_authentication.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,19 @@
88
import pytest
99
import responses
1010
import respx
11+
from fastapi import HTTPException
1112
from pydantic import SecretStr
1213
from starlette.status import HTTP_200_OK, HTTP_403_FORBIDDEN
1314

1415
from blueapi.config import OIDCConfig, ServiceAccount
15-
from blueapi.service import main
16+
from blueapi.service import authentication
1617
from blueapi.service.authentication import (
1718
SessionCacheManager,
1819
SessionManager,
1920
TiledAuth,
21+
access_token,
22+
build_access_token_check,
23+
unchecked_bearer_token,
2024
)
2125

2226

@@ -124,18 +128,18 @@ def test_poll_for_token_timeout(
124128
def test_server_raises_exception_for_invalid_token(
125129
oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock
126130
):
127-
inner = main.decode_access_token(oidc_config)
131+
inner = authentication.build_access_token_check(oidc_config)
128132
with pytest.raises(jwt.PyJWTError):
129-
inner(Mock(), access_token="Invalid Token")
133+
inner(Mock(), token="Invalid Token")
130134

131135

132136
def test_processes_valid_token(
133137
oidc_config: OIDCConfig,
134138
mock_authn_server: responses.RequestsMock,
135139
valid_token_with_jwt,
136140
):
137-
inner = main.decode_access_token(oidc_config)
138-
inner(Mock(), access_token=valid_token_with_jwt["access_token"])
141+
inner = authentication.build_access_token_check(oidc_config)
142+
inner(Mock(), token=valid_token_with_jwt["access_token"])
139143

140144

141145
def test_session_cache_manager_returns_writable_file_path(tmp_path):
@@ -182,3 +186,49 @@ def test_tiled_auth_sync_auth_flow():
182186
result = next(flow)
183187

184188
assert result.headers["Authorization"] == f"Bearer {access_token}"
189+
190+
191+
@pytest.mark.parametrize(
192+
"header,token",
193+
[
194+
(None, None),
195+
("ApiKey foobar", None),
196+
("Bearer foobar", "foobar"),
197+
("Bearer with_whitespace ", "with_whitespace"),
198+
("Bearerfoobar", None),
199+
],
200+
)
201+
def test_unchecked_bearer_token(header: str | None, token: str | None):
202+
req = Mock()
203+
req.headers.get.side_effect = lambda key: header if key == "Authorization" else None
204+
205+
assert unchecked_bearer_token(req) == token
206+
207+
208+
def test_access_token():
209+
req = Mock()
210+
req.state.decoded_access_token = {"foo": "bar"}
211+
212+
assert access_token(req) == {"foo": "bar"}
213+
214+
215+
def test_access_token_without_token():
216+
req = Mock()
217+
del req.state.decoded_access_token
218+
219+
assert access_token(req) is None
220+
221+
222+
@patch("blueapi.service.authentication.jwt")
223+
def test_build_access_token(mock_jwt: Mock):
224+
# Return None when building client to ensure no field/method access
225+
mock_jwt.PyJWKClient.return_value = None
226+
oidc_config = Mock()
227+
req = Mock()
228+
229+
validate_fn = build_access_token_check(oidc_config)
230+
231+
with pytest.raises(HTTPException, match="401"):
232+
validate_fn(req, token=None)
233+
234+
mock_jwt.decode.assert_not_called()

0 commit comments

Comments
 (0)