11import base64
22import hashlib
33import time
4+ from base64 import b64decode
45from dataclasses import dataclass
56from typing import Annotated , Any , Literal
67
@@ -19,9 +20,6 @@ class AuthorizationCodeRequest(BaseModel):
1920 grant_type : Literal ["authorization_code" ]
2021 code : str = Field (..., description = "The authorization code" )
2122 redirect_uri : AnyUrl | None = Field (None , description = "Must be the same as redirect URI provided in /authorize" )
22- client_id : str
23- # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
24- client_secret : str | None = None
2523 # See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
2624 code_verifier : str = Field (..., description = "PKCE code verifier" )
2725
@@ -31,9 +29,50 @@ class RefreshTokenRequest(BaseModel):
3129 grant_type : Literal ["refresh_token" ]
3230 refresh_token : str = Field (..., description = "The refresh token" )
3331 scope : str | None = Field (None , description = "Optional scope parameter" )
32+
33+
34+ class NoneCredentials (BaseModel ):
35+ client_id : str
36+ client_secret : None = None
37+
38+
39+ class PostCredentials (BaseModel ):
3440 client_id : str
3541 # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
36- client_secret : str | None = None
42+ client_secret : str
43+
44+
45+ class FormCredentials (
46+ RootModel [
47+ Annotated [
48+ NoneCredentials | PostCredentials ,
49+ Field (discriminator = "client_secret" ),
50+ ]
51+ ]
52+ ):
53+ root : Annotated [
54+ NoneCredentials | PostCredentials ,
55+ Field (discriminator = "client_secret" ),
56+ ]
57+
58+
59+ class BasicCredentials (BaseModel ):
60+ client_id : str
61+ client_secret : str
62+
63+ @classmethod
64+ def from_authorization (cls , authorization : str ):
65+ try :
66+ if authorization .startswith ("Basic " ):
67+ [client_id , client_secret ] = b64decode (authorization .removeprefix ("Basic " )).decode ().split (":" , 1 )
68+ return cls (client_id = client_id , client_secret = client_secret )
69+ except Exception :
70+ # TODO: better error here??
71+ return None
72+ return None
73+
74+
75+ Credentials = NoneCredentials | PostCredentials | BasicCredentials
3776
3877
3978class TokenRequest (
@@ -90,19 +129,42 @@ async def handle(self, request: Request):
90129 try :
91130 form_data = await request .form ()
92131 token_request = TokenRequest .model_validate (dict (form_data )).root
132+ try :
133+ credentials = FormCredentials .model_validate (dict (form_data )).root
134+ except ValidationError :
135+ credentials = (
136+ BasicCredentials .from_authorization (authorization )
137+ if (authorization := request .headers .get ("Authorization" ))
138+ else None
139+ )
140+ if not credentials :
141+ return self .response (
142+ TokenErrorResponse (
143+ error = "invalid_request" ,
144+ error_description = "missing credentials" ,
145+ )
146+ )
93147 except ValidationError as validation_error :
94148 return self .response (
95149 TokenErrorResponse (
96150 error = "invalid_request" ,
97151 error_description = stringify_pydantic_error (validation_error ),
98152 )
99153 )
100-
101154 try :
102155 client_info = await self .client_authenticator .authenticate (
103- client_id = token_request .client_id ,
104- client_secret = token_request .client_secret ,
156+ client_id = credentials .client_id ,
157+ client_secret = credentials .client_secret ,
105158 )
159+ match client_info .token_endpoint_auth_method :
160+ case "none" if not isinstance (credentials , NoneCredentials ):
161+ raise AuthenticationError ("Invalid credentials for client token_endpoint_auth_method" )
162+ case "client_secret_post" if not isinstance (credentials , PostCredentials ):
163+ raise AuthenticationError ("Invalid credentials for client token_endpoint_auth_method" )
164+ case "client_secret_basic" if not isinstance (credentials , BasicCredentials ):
165+ raise AuthenticationError ("Invalid credentials for client token_endpoint_auth_method" )
166+ case _:
167+ pass
106168 except AuthenticationError as e :
107169 return self .response (
108170 TokenErrorResponse (
@@ -126,7 +188,7 @@ async def handle(self, request: Request):
126188 match token_request :
127189 case AuthorizationCodeRequest ():
128190 auth_code = await self .provider .load_authorization_code (client_info , token_request .code )
129- if auth_code is None or auth_code .client_id != token_request .client_id :
191+ if auth_code is None or auth_code .client_id != credentials .client_id :
130192 # if code belongs to different client, pretend it doesn't exist
131193 return self .response (
132194 TokenErrorResponse (
@@ -185,7 +247,7 @@ async def handle(self, request: Request):
185247
186248 case RefreshTokenRequest ():
187249 refresh_token = await self .provider .load_refresh_token (client_info , token_request .refresh_token )
188- if refresh_token is None or refresh_token .client_id != token_request .client_id :
250+ if refresh_token is None or refresh_token .client_id != credentials .client_id :
189251 # if token belongs to different client, pretend it doesn't exist
190252 return self .response (
191253 TokenErrorResponse (
0 commit comments