@@ -2,9 +2,10 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
22
33use async_trait:: async_trait;
44use oauth2:: {
5- AuthType , AuthUrl , AuthorizationCode , ClientId , ClientSecret , CsrfToken , EmptyExtraTokenFields ,
6- PkceCodeChallenge , PkceCodeVerifier , RedirectUrl , RefreshToken , RequestTokenError , Scope ,
7- StandardTokenResponse , TokenResponse , TokenUrl ,
5+ AsyncHttpClient , AuthType , AuthUrl , AuthorizationCode , ClientId , ClientSecret , CsrfToken ,
6+ EmptyExtraTokenFields , HttpClientError , HttpRequest , HttpResponse , PkceCodeChallenge ,
7+ PkceCodeVerifier , RedirectUrl , RefreshToken , RequestTokenError , Scope , StandardTokenResponse ,
8+ TokenResponse , TokenUrl ,
89 basic:: { BasicClient , BasicTokenType } ,
910} ;
1011use reqwest:: {
@@ -18,6 +19,39 @@ use tracing::{debug, error, warn};
1819
1920use crate :: transport:: common:: http_header:: HEADER_MCP_PROTOCOL_VERSION ;
2021
22+ /// Owned wrapper around [`reqwest::Client`] that implements [`AsyncHttpClient`] for oauth2.
23+ struct OAuthReqwestClient ( HttpClient ) ;
24+
25+ impl < ' c > AsyncHttpClient < ' c > for OAuthReqwestClient {
26+ type Error = HttpClientError < reqwest:: Error > ;
27+
28+ type Future = std:: pin:: Pin <
29+ Box < dyn std:: future:: Future < Output = Result < HttpResponse , Self :: Error > > + Send + Sync + ' c > ,
30+ > ;
31+
32+ fn call ( & ' c self , request : HttpRequest ) -> Self :: Future {
33+ Box :: pin ( async move {
34+ let response = self
35+ . 0
36+ . execute ( request. try_into ( ) . map_err ( Box :: new) ?)
37+ . await
38+ . map_err ( Box :: new) ?;
39+
40+ let mut builder = oauth2:: http:: Response :: builder ( )
41+ . status ( response. status ( ) )
42+ . version ( response. version ( ) ) ;
43+
44+ for ( name, value) in response. headers ( ) . iter ( ) {
45+ builder = builder. header ( name, value) ;
46+ }
47+
48+ builder
49+ . body ( response. bytes ( ) . await . map_err ( Box :: new) ?. to_vec ( ) )
50+ . map_err ( HttpClientError :: Http )
51+ } )
52+ }
53+ }
54+
2155const DEFAULT_EXCHANGE_URL : & str = "http://localhost" ;
2256
2357/// Stored credentials for OAuth2 authorization
@@ -872,7 +906,7 @@ impl AuthorizationManager {
872906 . exchange_code ( AuthorizationCode :: new ( code. to_string ( ) ) )
873907 . set_pkce_verifier ( pkce_verifier)
874908 . add_extra_param ( "resource" , self . base_url . to_string ( ) )
875- . request_async ( & http_client)
909+ . request_async ( & OAuthReqwestClient ( http_client) )
876910 . await
877911 {
878912 Ok ( token) => token,
@@ -961,7 +995,7 @@ impl AuthorizationManager {
961995
962996 let token_result = oauth_client
963997 . exchange_refresh_token ( & RefreshToken :: new ( refresh_token. secret ( ) . to_string ( ) ) )
964- . request_async ( & self . http_client )
998+ . request_async ( & OAuthReqwestClient ( self . http_client . clone ( ) ) )
965999 . await
9661000 . map_err ( |e| AuthError :: TokenRefreshFailed ( e. to_string ( ) ) ) ?;
9671001
0 commit comments