1- use axum:: {
2- extract:: State ,
3- routing:: { get, post} ,
4- Json , Router ,
5- } ;
1+ use axum:: { extract:: State , routing:: post, Json , Router } ;
62use axum_extra:: extract:: {
73 cookie:: { Cookie , SameSite } ,
84 PrivateCookieJar ,
@@ -11,6 +7,7 @@ use serde::{Deserialize, Serialize};
117use time:: Duration ;
128
139use crate :: {
10+ enterprise:: handlers:: desktop_client_mfa:: mfa_auth_callback,
1411 error:: ApiError ,
1512 handlers:: get_core_response,
1613 http:: AppState ,
@@ -21,13 +18,14 @@ use crate::{
2118} ;
2219
2320const COOKIE_MAX_AGE : Duration = Duration :: days ( 1 ) ;
24- static CSRF_COOKIE_NAME : & str = "csrf_proxy" ;
25- static NONCE_COOKIE_NAME : & str = "nonce_proxy" ;
21+ pub ( super ) static CSRF_COOKIE_NAME : & str = "csrf_proxy" ;
22+ pub ( super ) static NONCE_COOKIE_NAME : & str = "nonce_proxy" ;
2623
2724pub ( crate ) fn router ( ) -> Router < AppState > {
2825 Router :: new ( )
29- . route ( "/auth_info" , get ( auth_info) )
26+ . route ( "/auth_info" , post ( auth_info) )
3027 . route ( "/callback" , post ( auth_callback) )
28+ . route ( "/callback/mfa" , post ( mfa_auth_callback) )
3129}
3230
3331#[ derive( Serialize ) ]
@@ -46,17 +44,49 @@ impl AuthInfo {
4644 }
4745}
4846
47+ #[ derive( Deserialize , Debug , PartialEq , Eq ) ]
48+ pub ( crate ) enum FlowType {
49+ Enrollment ,
50+ Mfa ,
51+ }
52+
53+ impl std:: str:: FromStr for FlowType {
54+ type Err = ( ) ;
55+
56+ fn from_str ( s : & str ) -> Result < Self , Self :: Err > {
57+ match s. to_lowercase ( ) . as_str ( ) {
58+ "enrollment" => Ok ( FlowType :: Enrollment ) ,
59+ "mfa" => Ok ( FlowType :: Mfa ) ,
60+ _ => Err ( ( ) ) ,
61+ }
62+ }
63+ }
64+
65+ #[ derive( Deserialize , Debug ) ]
66+ struct RequestData {
67+ state : Option < String > ,
68+ #[ serde( rename = "type" ) ]
69+ flow_type : String ,
70+ }
71+
4972/// Request external OAuth2/OpenID provider details from Defguard Core.
5073#[ instrument( level = "debug" , skip( state) ) ]
5174async fn auth_info (
5275 State ( state) : State < AppState > ,
5376 device_info : DeviceInfo ,
5477 private_cookies : PrivateCookieJar ,
78+ Json ( request_data) : Json < RequestData > ,
5579) -> Result < ( PrivateCookieJar , Json < AuthInfo > ) , ApiError > {
5680 debug ! ( "Getting auth info for OAuth2/OpenID login" ) ;
5781
82+ let flow_type = request_data
83+ . flow_type
84+ . parse :: < FlowType > ( )
85+ . map_err ( |_| ApiError :: BadRequest ( "Invalid flow type" . into ( ) ) ) ?;
86+
5887 let request = AuthInfoRequest {
59- redirect_url : state. callback_url ( ) . to_string ( ) ,
88+ redirect_url : state. callback_url ( flow_type) . to_string ( ) ,
89+ state : request_data. state ,
6090 } ;
6191
6292 let rx = state
@@ -93,9 +123,11 @@ async fn auth_info(
93123}
94124
95125#[ derive( Debug , Deserialize ) ]
96- pub struct AuthenticationResponse {
97- code : String ,
98- state : String ,
126+ pub ( super ) struct AuthenticationResponse {
127+ pub ( super ) code : String ,
128+ pub ( super ) state : String ,
129+ #[ serde( rename = "type" ) ]
130+ pub ( super ) flow_type : String ,
99131}
100132
101133#[ derive( Serialize ) ]
@@ -111,6 +143,17 @@ async fn auth_callback(
111143 mut private_cookies : PrivateCookieJar ,
112144 Json ( payload) : Json < AuthenticationResponse > ,
113145) -> Result < ( PrivateCookieJar , Json < CallbackResponseData > ) , ApiError > {
146+ let flow_type = payload
147+ . flow_type
148+ . parse :: < FlowType > ( )
149+ . map_err ( |_| ApiError :: BadRequest ( "Invalid flow type" . into ( ) ) ) ?;
150+
151+ if flow_type != FlowType :: Enrollment {
152+ return Err ( ApiError :: BadRequest (
153+ "Invalid flow type for OpenID enrollment callback" . into ( ) ,
154+ ) ) ;
155+ }
156+
114157 let nonce = private_cookies
115158 . get ( NONCE_COOKIE_NAME )
116159 . ok_or ( ApiError :: Unauthorized ( "Nonce cookie not found" . into ( ) ) ) ?
@@ -133,13 +176,14 @@ async fn auth_callback(
133176 let request = AuthCallbackRequest {
134177 code : payload. code ,
135178 nonce,
136- callback_url : state. callback_url ( ) . to_string ( ) ,
179+ callback_url : state. callback_url ( flow_type ) . to_string ( ) ,
137180 } ;
138181
139182 let rx = state
140183 . grpc_server
141184 . send ( core_request:: Payload :: AuthCallback ( request) , device_info) ?;
142185 let payload = get_core_response ( rx) . await ?;
186+
143187 if let core_response:: Payload :: AuthCallback ( AuthCallbackResponse { url, token } ) = payload {
144188 debug ! ( "Received auth callback response {url:?} {token:?}" ) ;
145189 Ok ( ( private_cookies, Json ( CallbackResponseData { url, token } ) ) )
0 commit comments