11import logging
22
3+ import jwt
34from django .apps import apps as django_apps
45from django .conf import settings
56from django .utils .encoding import force_str
89from rest_framework import HTTP_HEADER_ENCODING , exceptions
910from rest_framework .authentication import BaseAuthentication
1011
11- from .validator import TokenError , TokenValidator
12+ from .validator import CognitoTokenValidator , EntraTokenValidator , TokenError
1213
1314logger = logging .getLogger (__name__ )
1415
@@ -22,8 +23,9 @@ def get_authorization_header(request):
2223
2324 Hide some test client ickyness where the header can be unicode.
2425 """
25- auth_header = getattr (settings , "COGNITO_JWT_AUTH_HEADER " , "Authorization " )
26+ auth_header = getattr (settings , "JWT_AUTH_HEADER " , "HTTP_AUTHORIZATION " )
2627 auth = request .META .get (auth_header , b"" )
28+
2729 if isinstance (auth , str ):
2830 # Work around django test client oddness
2931 auth = auth .encode (HTTP_HEADER_ENCODING )
@@ -44,18 +46,28 @@ def authenticate(self, request):
4446
4547 # Authenticate token
4648 try :
47- token_validator = self .get_token_validator (request )
49+ token_validator , provider_name = self .get_token_validator (jwt_token )
50+ except TokenError as e :
51+ logger .debug ("Failed to identify token provider: %s" , e )
52+ raise exceptions .AuthenticationFailed (
53+ _ ("Unknown or malformed token issuer." )
54+ ) from e
55+
56+ try :
4857 jwt_payload = token_validator .validate (jwt_token )
49- except TokenError :
58+ except TokenError as e :
59+ logger .debug (
60+ "%s token validation failed: %s" , provider_name .capitalize (), e
61+ )
5062 raise exceptions .AuthenticationFailed from None
5163
52- custom_user_manager = self .get_custom_user_manager ()
64+ custom_user_manager = self .get_custom_user_manager (provider_name )
5365
5466 if custom_user_manager :
55- user = custom_user_manager .get_or_create_for_cognito (jwt_payload )
67+ user = custom_user_manager .get_or_create (jwt_payload )
5668 else :
5769 user_model = self .get_user_model ()
58- user = user_model .objects .get_or_create_for_cognito (jwt_payload )
70+ user = user_model .objects .get_or_create (jwt_payload )
5971
6072 if not user :
6173 logger .debug (
@@ -66,19 +78,27 @@ def authenticate(self, request):
6678 return (user , jwt_token )
6779
6880 @staticmethod
69- def get_custom_user_manager ():
70- """If COGNITO_USER_MANAGER is set, then the user object is obtained
71- via get_or_create_for_cognito on the user manager, this allows use
81+ def get_custom_user_manager (provider = "cognito" ):
82+ """If COGNITO_USER_MANAGER or ENTRA_USER_MANAGER is set, then the user object is obtained
83+ via get_or_create_for_cognito (or get_or_create_for_entra) on the user manager, this allows use
7284 of the default unmodified Django User model"""
7385 result = None
74- custom_user_manager_path = getattr (settings , "COGNITO_USER_MANAGER" , False )
86+ custom_user_manager_path = (
87+ getattr (settings , "ENTRA_USER_MANAGER" , False )
88+ if provider == "entra"
89+ else getattr (settings , "COGNITO_USER_MANAGER" , False )
90+ )
7591 if custom_user_manager_path :
7692 result = import_string (custom_user_manager_path )()
7793 return result
7894
7995 @staticmethod
80- def get_user_model ():
81- user_model = getattr (settings , "COGNITO_USER_MODEL" , settings .AUTH_USER_MODEL )
96+ def get_user_model (provider = "cognito" ):
97+ user_model = (
98+ getattr (settings , "ENTRA_USER_MODEL" , settings .AUTH_USER_MODEL )
99+ if provider == "entra"
100+ else getattr (settings , "COGNITO_USER_MODEL" , settings .AUTH_USER_MODEL )
101+ )
82102 return django_apps .get_model (user_model , require_ready = False )
83103
84104 @staticmethod
@@ -100,12 +120,33 @@ def get_jwt_token(request):
100120 return auth [1 ]
101121
102122 @staticmethod
103- def get_token_validator (request ):
104- return TokenValidator (
105- settings .COGNITO_AWS_REGION ,
106- settings .COGNITO_USER_POOL ,
107- settings .COGNITO_AUDIENCE ,
108- )
123+ def get_token_validator (jwt_token ):
124+ try :
125+ # Decode without verifying signature just to read the header/payload
126+ unverified_payload = jwt .decode (
127+ jwt_token , options = {"verify_signature" : False } # noqa: S5659
128+ )
129+ issuer = unverified_payload .get ("iss" , "" )
130+ except jwt .PyJWTError as e :
131+ raise exceptions .AuthenticationFailed (_ ("Malformed JWT." )) from e
132+
133+ if "cognito-idp" in issuer :
134+ validator = CognitoTokenValidator (
135+ settings .COGNITO_AWS_REGION ,
136+ settings .COGNITO_USER_POOL ,
137+ settings .COGNITO_AUDIENCE ,
138+ )
139+ return validator , "cognito"
140+
141+ if "sts.windows.net" in issuer :
142+ validator = EntraTokenValidator (
143+ settings .ENTRA_TENANT_ID ,
144+ settings .ENTRA_AUDIENCE ,
145+ settings .ENTRA_ALLOWED_APP_IDS ,
146+ )
147+ return validator , "entra"
148+
149+ raise exceptions .AuthenticationFailed (_ ("Invalid or unsupported token issuer." ))
109150
110151 @staticmethod
111152 def authenticate_header (request ):
0 commit comments