1515import logging
1616import secrets
1717from base64 import b64decode , b64encode
18+ from typing import Literal
1819
1920import click
2021from pydantic import AnyHttpUrl , BaseModel
@@ -46,6 +47,8 @@ class DiscordOAuthSettings(BaseSettings):
4647 discord_client_id : str | None = None
4748 discord_client_secret : str | None = None
4849
50+ token_endpoint_auth_method : Literal ["client_secret_basic" , "client_secret_post" ] = "client_secret_basic"
51+
4952 # Discord OAuth URL
5053 discord_token_url : str = f"{ API_ENDPOINT } /oauth2/token"
5154
@@ -96,13 +99,21 @@ def __init__(self, scope: Scope, receive: Receive, send: Send):
9699 }
97100
98101 async def post (self , request : Request ) -> Response :
99- # Get client_id and client_secret from Basic auth header
100- auth_header = request .headers .get ("Authorization" , "" )
101- if not auth_header .startswith ("Basic " ):
102- return JSONResponse ({"error" : "Invalid authorization header" }, status_code = 401 )
103- auth_header_encoded = auth_header .split (" " )[1 ]
104- auth_header_raw = b64decode (auth_header_encoded ).decode ("utf-8" )
105- client_id , client_secret = auth_header_raw .split (":" )
102+ # Get request data (application/x-www-form-urlencoded)
103+ data = await request .form ()
104+
105+ if self .discord_settings .token_endpoint_auth_method == "client_secret_basic" :
106+ # Get client_id and client_secret from Basic auth header
107+ auth_header = request .headers .get ("Authorization" , "" )
108+ if not auth_header .startswith ("Basic " ):
109+ return JSONResponse ({"error" : "Invalid authorization header" }, status_code = 401 )
110+ auth_header_encoded = auth_header .split (" " )[1 ]
111+ auth_header_raw = b64decode (auth_header_encoded ).decode ("utf-8" )
112+ client_id , client_secret = auth_header_raw .split (":" )
113+ else :
114+ # Get from body
115+ client_id = str (data .get ("client_id" ))
116+ client_secret = str (data .get ("client_secret" ))
106117
107118 # Validate MCP client
108119 if client_id not in self .client_map :
@@ -115,9 +126,6 @@ async def post(self, request: Request) -> Response:
115126 discord_client_id = self .client_map [client_id ]
116127 discord_client_secret = self .discord_client_credentials [discord_client_id ]
117128
118- # Get request data (application/x-www-form-urlencoded)
119- data = await request .form ()
120-
121129 # Validate scopes
122130 scopes = str (data .get ("scope" , "" )).split (" " )
123131 if not set (scopes ).issubset (set (self .discord_settings .discord_scope .split (" " ))):
@@ -208,7 +216,7 @@ def create_authorization_server(
208216 issuer = server_settings .server_url ,
209217 authorization_endpoint = AnyHttpUrl (f"{ server_settings .server_url } authorize" ),
210218 token_endpoint = AnyHttpUrl (f"{ server_settings .server_url } token" ),
211- token_endpoint_auth_methods_supported = ["client_secret_basic " ],
219+ token_endpoint_auth_methods_supported = ["client_secret_post " ],
212220 response_types_supported = ["code" ],
213221 grant_types_supported = ["client_credentials" ],
214222 scopes_supported = [discord_settings .discord_scope ],
0 commit comments