Skip to content

Commit 73ffad5

Browse files
Enabling Entra Auth (#3251)
Co-authored-by: sahmed06 <58435820+sahmed06@users.noreply.github.com>
1 parent 6e681ab commit 73ffad5

16 files changed

Lines changed: 762 additions & 267 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ you will need to set up the variables for validating the token via cognito:
333333

334334
- `export COGNITO_AWS_REGION=eu-west-2` - This is unlikely to change
335335
- `export COGNITO_USER_POOL=eu-west-2_a123bc4DE` - Can be found be checking the `User pool ID` value for your environment on the [AWS console] (https://eu-west-2.console.aws.amazon.com/cognito/v2/idp/user-pools?region=eu-west-2)
336-
- `export COGNITO_JWT_AUTH_HEADER=HTTP_X_UHD_AUTH` - This is unlikely to change
336+
- `export JWT_AUTH_HEADER=HTTP_X_UHD_AUTH` - This is unlikely to change
337337

338338
---
339339

common/auth/cognito_jwt/user_manager.py

Lines changed: 0 additions & 37 deletions
This file was deleted.
Lines changed: 60 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22

3+
import jwt
34
from django.apps import apps as django_apps
45
from django.conf import settings
56
from django.utils.encoding import force_str
@@ -8,7 +9,7 @@
89
from rest_framework import HTTP_HEADER_ENCODING, exceptions
910
from rest_framework.authentication import BaseAuthentication
1011

11-
from .validator import TokenError, TokenValidator
12+
from .validator import CognitoTokenValidator, EntraTokenValidator, TokenError
1213

1314
logger = logging.getLogger(__name__)
1415

@@ -22,8 +23,9 @@ def get_authorization_header(request):
2223
2324
Hide some test client ickyness where the header can be unicode.
2425
"""
25-
auth_header = getattr(settings, "COGNITO_JWT_AUTH_HEADER", "Authorization")
26+
auth_header = getattr(settings, "JWT_AUTH_HEADER", "HTTP_AUTHORIZATION")
2627
auth = request.META.get(auth_header, b"")
28+
2729
if isinstance(auth, str):
2830
# Work around django test client oddness
2931
auth = auth.encode(HTTP_HEADER_ENCODING)
@@ -44,18 +46,28 @@ def authenticate(self, request):
4446

4547
# Authenticate token
4648
try:
47-
token_validator = self.get_token_validator(request)
49+
token_validator, provider_name = self.get_token_validator(jwt_token)
50+
except TokenError as e:
51+
logger.debug("Failed to identify token provider: %s", e)
52+
raise exceptions.AuthenticationFailed(
53+
_("Unknown or malformed token issuer.")
54+
) from e
55+
56+
try:
4857
jwt_payload = token_validator.validate(jwt_token)
49-
except TokenError:
58+
except TokenError as e:
59+
logger.debug(
60+
"%s token validation failed: %s", provider_name.capitalize(), e
61+
)
5062
raise exceptions.AuthenticationFailed from None
5163

52-
custom_user_manager = self.get_custom_user_manager()
64+
custom_user_manager = self.get_custom_user_manager(provider_name)
5365

5466
if custom_user_manager:
55-
user = custom_user_manager.get_or_create_for_cognito(jwt_payload)
67+
user = custom_user_manager.get_or_create(jwt_payload)
5668
else:
5769
user_model = self.get_user_model()
58-
user = user_model.objects.get_or_create_for_cognito(jwt_payload)
70+
user = user_model.objects.get_or_create(jwt_payload)
5971

6072
if not user:
6173
logger.debug(
@@ -66,19 +78,27 @@ def authenticate(self, request):
6678
return (user, jwt_token)
6779

6880
@staticmethod
69-
def get_custom_user_manager():
70-
"""If COGNITO_USER_MANAGER is set, then the user object is obtained
71-
via get_or_create_for_cognito on the user manager, this allows use
81+
def get_custom_user_manager(provider="cognito"):
82+
"""If COGNITO_USER_MANAGER or ENTRA_USER_MANAGER is set, then the user object is obtained
83+
via get_or_create_for_cognito (or get_or_create_for_entra) on the user manager, this allows use
7284
of the default unmodified Django User model"""
7385
result = None
74-
custom_user_manager_path = getattr(settings, "COGNITO_USER_MANAGER", False)
86+
custom_user_manager_path = (
87+
getattr(settings, "ENTRA_USER_MANAGER", False)
88+
if provider == "entra"
89+
else getattr(settings, "COGNITO_USER_MANAGER", False)
90+
)
7591
if custom_user_manager_path:
7692
result = import_string(custom_user_manager_path)()
7793
return result
7894

7995
@staticmethod
80-
def get_user_model():
81-
user_model = getattr(settings, "COGNITO_USER_MODEL", settings.AUTH_USER_MODEL)
96+
def get_user_model(provider="cognito"):
97+
user_model = (
98+
getattr(settings, "ENTRA_USER_MODEL", settings.AUTH_USER_MODEL)
99+
if provider == "entra"
100+
else getattr(settings, "COGNITO_USER_MODEL", settings.AUTH_USER_MODEL)
101+
)
82102
return django_apps.get_model(user_model, require_ready=False)
83103

84104
@staticmethod
@@ -100,12 +120,33 @@ def get_jwt_token(request):
100120
return auth[1]
101121

102122
@staticmethod
103-
def get_token_validator(request):
104-
return TokenValidator(
105-
settings.COGNITO_AWS_REGION,
106-
settings.COGNITO_USER_POOL,
107-
settings.COGNITO_AUDIENCE,
108-
)
123+
def get_token_validator(jwt_token):
124+
try:
125+
# Decode without verifying signature just to read the header/payload
126+
unverified_payload = jwt.decode(
127+
jwt_token, options={"verify_signature": False} # noqa: S5659
128+
)
129+
issuer = unverified_payload.get("iss", "")
130+
except jwt.PyJWTError as e:
131+
raise exceptions.AuthenticationFailed(_("Malformed JWT.")) from e
132+
133+
if "cognito-idp" in issuer:
134+
validator = CognitoTokenValidator(
135+
settings.COGNITO_AWS_REGION,
136+
settings.COGNITO_USER_POOL,
137+
settings.COGNITO_AUDIENCE,
138+
)
139+
return validator, "cognito"
140+
141+
if "sts.windows.net" in issuer:
142+
validator = EntraTokenValidator(
143+
settings.ENTRA_TENANT_ID,
144+
settings.ENTRA_AUDIENCE,
145+
settings.ENTRA_ALLOWED_APP_IDS,
146+
)
147+
return validator, "entra"
148+
149+
raise exceptions.AuthenticationFailed(_("Invalid or unsupported token issuer."))
109150

110151
@staticmethod
111152
def authenticate_header(request):

common/auth/jwt/user_manager.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import logging
2+
3+
from django.contrib.auth import get_user_model
4+
from django.contrib.auth.models import BaseUserManager
5+
from rest_framework import exceptions
6+
7+
from cms.auth_content.models.users import User
8+
from metrics.data.managers.rbac_models.user import UserManager
9+
from metrics.utils.permission_hierarchy import build_permission_hierarchy
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
def get_user_permission_set(user_id: str):
15+
permissions = UserManager.get_permission_sets_for_user(user_id)
16+
return build_permission_hierarchy(permissions)
17+
18+
19+
class CognitoManager(BaseUserManager):
20+
21+
@staticmethod
22+
def get_or_create(jwt_payload):
23+
"""Create an ephemeral user instance for this request.
24+
If the permissions aren't present in the JWT, queries for them in
25+
the database based on the entraObjectId in the token
26+
"""
27+
try:
28+
username = jwt_payload["entraObjectId"]
29+
# Check if the JWT already includes permissionSets
30+
# Use if found, if not grab user permissions from the database
31+
if "permissionSets" in jwt_payload and jwt_payload["permissionSets"] != []:
32+
permission_sets = jwt_payload["permissionSets"]
33+
else:
34+
permission_sets = get_user_permission_set(username)
35+
except KeyError:
36+
logger.debug(
37+
"Error getting entraObjectId and/or permissionSets field(s)"
38+
" from jwt payload: '%s'",
39+
jwt_payload,
40+
)
41+
return None
42+
43+
user_class = get_user_model()
44+
user = user_class(username=username)
45+
user.permission_sets = permission_sets
46+
return user
47+
48+
49+
class EntraManager(BaseUserManager):
50+
51+
@staticmethod
52+
def get_or_create(jwt_payload):
53+
"""Create an ephemeral user instance for this request.
54+
If the provided appid isn't present in the database, raises
55+
AuthenticationFailed exception
56+
"""
57+
try:
58+
username = jwt_payload["appid"]
59+
if not User.objects.filter(user_id=username).exists():
60+
msg = "Application not found."
61+
raise exceptions.AuthenticationFailed(msg)
62+
permission_sets = get_user_permission_set(username)
63+
except KeyError:
64+
logger.info(
65+
"Error getting entraObjectId and/or permissionSets field(s)"
66+
" from jwt payload: '%s'",
67+
jwt_payload,
68+
)
69+
return None
70+
71+
user_class = get_user_model()
72+
user = user_class(username=username)
73+
user.permission_sets = permission_sets
74+
return user
Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class TokenError(Exception):
1515
pass
1616

1717

18-
class TokenValidator:
18+
class CognitoTokenValidator:
1919
def __init__(self, aws_region, aws_user_pool, audience):
2020
self.aws_region = aws_region
2121
self.aws_user_pool = aws_user_pool
@@ -86,3 +86,78 @@ def validate(self, token):
8686
) as exc:
8787
raise TokenError(str(exc)) from exc
8888
return jwt_data
89+
90+
91+
class EntraTokenValidator:
92+
def __init__(self, tenant_id, audience, allowed_app_ids):
93+
self.tenant_id = tenant_id
94+
self.audience = audience
95+
self.allowed_app_ids = allowed_app_ids
96+
self.jwks_url = "https://login.microsoftonline.com/common/discovery/keys"
97+
98+
@cached_property
99+
def expected_issuer(self):
100+
return f"https://sts.windows.net/{self.tenant_id}/"
101+
102+
@cached_property
103+
def _json_web_keys(self):
104+
response = requests.get(self.jwks_url, timeout=10)
105+
response.raise_for_status()
106+
json_data = response.json()
107+
return {item["kid"]: json.dumps(item) for item in json_data["keys"]}
108+
109+
def _get_public_key(self, token):
110+
try:
111+
headers = jwt.get_unverified_header(token)
112+
except jwt.DecodeError as exc:
113+
raise TokenError(str(exc)) from exc
114+
115+
if getattr(settings, "ENTRA_PUBLIC_KEYS_CACHING_ENABLED", False):
116+
cache_key = "entra_jwt:{}".format(headers["kid"])
117+
jwk_data = cache.get(cache_key)
118+
119+
if not jwk_data:
120+
jwk_data = self._json_web_keys.get(headers["kid"])
121+
timeout = getattr(settings, "ENTRA_PUBLIC_KEYS_CACHING_TIMEOUT", 300)
122+
cache.set(cache_key, jwk_data, timeout=timeout)
123+
else:
124+
jwk_data = self._json_web_keys.get(headers["kid"])
125+
126+
if jwk_data:
127+
return RSAAlgorithm.from_jwk(jwk_data)
128+
return None
129+
130+
def validate(self, token):
131+
public_key = self._get_public_key(token)
132+
if not public_key:
133+
msg = "No key found for this token"
134+
raise TokenError(msg)
135+
136+
params = {
137+
"jwt": token,
138+
"key": public_key,
139+
"issuer": self.expected_issuer,
140+
"audience": self.audience,
141+
"algorithms": ["RS256"],
142+
}
143+
144+
try:
145+
payload = jwt.decode(**params)
146+
except (
147+
jwt.InvalidTokenError,
148+
jwt.ExpiredSignatureError,
149+
jwt.DecodeError,
150+
) as exc:
151+
raise TokenError(str(exc)) from exc
152+
153+
roles = payload.get("roles", [])
154+
if "application.read" not in roles:
155+
msg = "Missing required role: application.read"
156+
raise TokenError(msg)
157+
158+
app_id_claim = payload.get("appid") or payload.get("azp")
159+
if app_id_claim not in self.allowed_app_ids:
160+
msg = "Invalid app_id claim"
161+
raise TokenError(msg)
162+
163+
return payload

config.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,18 @@
6161
# The name of the AWS profile to use for the AWS client used for ingestion
6262
AWS_PROFILE_NAME = os.environ.get("AWS_PROFILE_NAME")
6363

64+
JWT_AUTH_HEADER = os.environ.get("JWT_AUTH_HEADER", "HTTP_AUTHORIZATION")
65+
6466
# Cognito configuration
6567
COGNITO_AWS_REGION = os.environ.get("COGNITO_AWS_REGION")
66-
COGNITO_JWT_AUTH_HEADER = os.environ.get("COGNITO_JWT_AUTH_HEADER")
6768
COGNITO_USER_POOL = os.environ.get("COGNITO_USER_POOL")
6869

70+
# Entra configuration
71+
ENTRA_AUDIENCE = os.environ.get("ENTRA_AUDIENCE")
72+
ENTRA_APP_ID = os.environ.get("ENTRA_APP_ID")
73+
ENTRA_ALLOWED_APP_IDS = os.environ.get("ENTRA_ALLOWED_APP_IDS", "")
74+
ENTRA_TENANT_ID = os.environ.get("ENTRA_TENANT_ID")
75+
6976
# Database configuration
7077
POSTGRES_DB = os.environ.get("POSTGRES_DB")
7178
POSTGRES_USER = os.environ.get("POSTGRES_USER")

metrics/api/settings/default.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,18 @@
116116
},
117117
]
118118

119-
COGNITO_USER_MANAGER = "common.auth.cognito_jwt.user_manager.CognitoManager"
119+
JWT_AUTH_HEADER = config.JWT_AUTH_HEADER
120+
121+
ENTRA_USER_MANAGER = "common.auth.jwt.user_manager.EntraManager"
122+
ENTRA_AUDIENCE = config.ENTRA_AUDIENCE
123+
ENTRA_APP_ID = config.ENTRA_APP_ID
124+
ENTRA_ALLOWED_APP_IDS = config.ENTRA_ALLOWED_APP_IDS.split(",")
125+
ENTRA_TENANT_ID = config.ENTRA_TENANT_ID
126+
ENTRA_PUBLIC_KEYS_CACHING_ENABLED = True
127+
ENTRA_PUBLIC_KEYS_CACHING_TIMEOUT = 60 * 60 * 24 # 24h caching, default is 300s
128+
129+
COGNITO_USER_MANAGER = "common.auth.jwt.user_manager.CognitoManager"
120130
COGNITO_AWS_REGION = config.COGNITO_AWS_REGION
121-
COGNITO_JWT_AUTH_HEADER = config.COGNITO_JWT_AUTH_HEADER
122131
COGNITO_USER_POOL = config.COGNITO_USER_POOL
123132
COGNITO_AUDIENCE = None
124133
COGNITO_PUBLIC_KEYS_CACHING_ENABLED = True
@@ -128,7 +137,7 @@
128137
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
129138
"DEFAULT_AUTHENTICATION_CLASSES": [
130139
"rest_framework.authentication.SessionAuthentication",
131-
"common.auth.cognito_jwt.JSONWebTokenAuthentication",
140+
"common.auth.jwt.JSONWebTokenAuthentication",
132141
],
133142
}
134143

0 commit comments

Comments
 (0)