Skip to content

Commit 17d523a

Browse files
committed
fix(auth): Use token in AuthTuple
While implementing authz I've made the authentication module put the claims in the `AuthTuple` instead of putting the original token because I assumed the original token was unnecessary and I didn't want to make the authz module re-parse the token to get the claims. That was a mistake, because we need to pass on the user token to the MCP server for it to validate the user's access. So this commit changes the `AuthTuple` to contain the original token instead of the claims. The authz module has been updated to parse the claims from the token instead of receiving them as a JSON string. Somewhat hackily, since we don't want to re-verify the token signature, so we assume that the token has already been verified during authentication and just decode the claims from the middle section of the JWT token.
1 parent 3e2d883 commit 17d523a

4 files changed

Lines changed: 34 additions & 18 deletions

File tree

src/auth/jwk_token.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Manage authentication flow for FastAPI endpoints with JWK based JWT auth."""
22

33
import logging
4-
import json
54
from asyncio import Lock
65
from typing import Any, Callable
76

@@ -191,4 +190,4 @@ async def __call__(self, request: Request) -> AuthTuple:
191190

192191
logger.info("Successfully authenticated user %s (ID: %s)", username, user_id)
193192

194-
return user_id, username, json.dumps(claims)
193+
return user_id, username, user_token

src/authorization/resolvers.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Authorization resolvers for role evaluation and access control."""
22

33
from abc import ABC, abstractmethod
4-
import json
54
import logging
5+
import base64
6+
import json
67
from typing import Any
78

89
from jsonpath_ng import parse
@@ -38,6 +39,17 @@ async def resolve_roles(self, auth: AuthTuple) -> UserRoles:
3839
return set()
3940

4041

42+
def unsafe_get_claims(token: str) -> dict[str, Any]:
43+
"""Get claims from a token without validating the signature.
44+
45+
A somewhat hacky way to get JWT claims without verifying the signature.
46+
We assume verification has already been done during authentication.
47+
"""
48+
payload = token.split(".")[1]
49+
padded = payload + "=" * (-len(payload) % 4)
50+
return json.loads(base64.urlsafe_b64decode(padded))
51+
52+
4153
class JwtRolesResolver(RolesResolver): # pylint: disable=too-few-public-methods
4254
"""Processes JWT claims with the given JSONPath rules to get roles."""
4355

@@ -76,7 +88,7 @@ def _get_claims(auth: AuthTuple) -> dict[str, Any]:
7688
# No claims for guests
7789
return {}
7890

79-
jwt_claims = json.loads(token)
91+
jwt_claims = unsafe_get_claims(token)
8092

8193
if not jwt_claims:
8294
raise RoleResolutionError(

tests/unit/auth/test_jwk_token.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
"""Unit tests for functions defined in auth/jwk_token.py"""
44

55
import time
6-
import json
7-
import base64
86

97
import pytest
108
from fastapi import HTTPException, Request
@@ -174,19 +172,12 @@ def set_auth_header(request: Request, token: str):
174172
request.scope["headers"] = new_headers
175173

176174

177-
def get_claims(token: str):
178-
"""Extract claims from a JWT token without validating it."""
179-
payload = token.split(".")[1]
180-
padded = payload + "=" * (-len(payload) % 4)
181-
return json.loads(base64.urlsafe_b64decode(padded))
182-
183-
184-
def ensure_test_user_id_and_name(auth_tuple, token):
175+
def ensure_test_user_id_and_name(auth_tuple, expected_token):
185176
"""Utility to ensure that the values in the auth tuple match the test values."""
186-
user_id, username, token_claims = auth_tuple
177+
user_id, username, token = auth_tuple
187178
assert user_id == TEST_USER_ID
188179
assert username == TEST_USER_NAME
189-
assert json.loads(token_claims) == get_claims(token)
180+
assert token == expected_token
190181

191182

192183
async def test_valid(

tests/unit/authorization/test_resolvers.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,23 @@
11
"""Unit tests for the authorization resolvers."""
22

3+
import json
4+
import base64
5+
36
from authorization.resolvers import JwtRolesResolver, GenericAccessResolver
47
from models.config import JwtRoleRule, AccessRule, JsonPathOperator, Action
58

69

10+
def claims_to_token(claims: dict) -> str:
11+
"""Convert JWT claims dictionary to a JSON string token."""
12+
13+
string_claims = json.dumps(claims)
14+
b64_encoded_claims = (
15+
base64.urlsafe_b64encode(string_claims.encode()).decode().rstrip("=")
16+
)
17+
18+
return f"foo_header.{b64_encoded_claims}.foo_signature"
19+
20+
721
class TestJwtRolesResolver:
822
"""Test cases for JwtRolesResolver."""
923

@@ -33,7 +47,7 @@ async def test_resolve_roles_redhat_employee(self):
3347
}
3448

3549
# Mock auth tuple with JWT claims as third element
36-
auth = ("user", "token", str(jwt_claims).replace("'", '"'))
50+
auth = ("user", "token", claims_to_token(jwt_claims))
3751
roles = await jwt_resolver.resolve_roles(auth)
3852
assert "employee" in roles
3953

@@ -57,7 +71,7 @@ async def test_resolve_roles_no_match(self):
5771
}
5872

5973
# Mock auth tuple with JWT claims as third element
60-
auth = ("user", "token", str(jwt_claims).replace("'", '"'))
74+
auth = ("user", "token", claims_to_token(jwt_claims))
6175
roles = await jwt_resolver.resolve_roles(auth)
6276
assert len(roles) == 0
6377

0 commit comments

Comments
 (0)