1- use std:: collections :: hash_map :: Entry ;
1+ use std:: time :: Duration ;
22
33use axum:: {
44 extract:: {
@@ -12,18 +12,23 @@ use axum::{
1212use futures_util:: { sink:: SinkExt , stream:: StreamExt } ;
1313use serde:: Deserialize ;
1414use serde_json:: json;
15- use tokio:: { sync :: oneshot , task:: JoinSet } ;
15+ use tokio:: task:: JoinSet ;
1616
1717use crate :: {
1818 error:: ApiError ,
1919 handlers:: get_core_response,
2020 http:: AppState ,
2121 proto:: {
22- core_request, core_response, ClientMfaFinishRequest , ClientMfaFinishResponse ,
22+ core_request,
23+ core_response:: { self , Payload } ,
24+ AwaitRemoteMfaFinishRequest , ClientMfaFinishRequest , ClientMfaFinishResponse ,
2325 ClientMfaStartRequest , ClientMfaStartResponse , DeviceInfo ,
2426 } ,
2527} ;
2628
29+ // How much time the user has to approve remote MFA with mobile device
30+ const REMOTE_AUTH_TIMEOUT : Duration = Duration :: from_secs ( 60 ) ;
31+
2732pub ( crate ) fn router ( ) -> Router < AppState > {
2833 Router :: new ( )
2934 . route ( "/start" , post ( start_client_mfa) )
@@ -53,66 +58,74 @@ async fn await_remote_auth(
5358 token : token. clone ( ) ,
5459 } ,
5560 ) ,
56- device_info,
61+ device_info. clone ( ) ,
5762 ) ?;
58- let payload = get_core_response ( rx) . await ?;
63+ let payload = get_core_response ( rx, Some ( REMOTE_AUTH_TIMEOUT ) ) . await ?;
5964 if let core_response:: Payload :: ClientMfaTokenValidation ( response) = payload {
6065 if !response. token_valid {
6166 return Err ( ApiError :: Unauthorized ( String :: new ( ) ) ) ;
6267 }
63- // check if its already in the map
64- let contains_key = {
65- let sessions = state. remote_mfa_sessions . lock ( ) . await ;
66- sessions. contains_key ( & token)
67- } ;
68- if contains_key {
69- return Err ( ApiError :: Unauthorized ( String :: new ( ) ) ) ;
70- }
71- Ok ( ws. on_upgrade ( move |socket| handle_remote_auth_socket ( socket, state. clone ( ) , token) ) )
68+
69+ Ok ( ws. on_upgrade ( move |socket| {
70+ handle_remote_auth_socket ( socket, state. clone ( ) , token, device_info)
71+ } ) )
7272 } else {
7373 Err ( ApiError :: InvalidResponseType )
7474 }
7575}
7676
7777/// Handle axum web socket upgrade for `await_remote_auth`.
78- async fn handle_remote_auth_socket ( socket : WebSocket , state : AppState , token : String ) {
79- let ( tx, rx) = oneshot:: channel ( ) ;
80-
81- {
82- let mut sessions = state. remote_mfa_sessions . lock ( ) . await ;
83- match sessions. entry ( token. clone ( ) ) {
84- Entry :: Occupied ( _) => {
85- return ;
86- }
87- Entry :: Vacant ( v) => {
88- v. insert ( tx) ;
89- }
90- }
91- }
92-
78+ async fn handle_remote_auth_socket (
79+ socket : WebSocket ,
80+ state : AppState ,
81+ token : String ,
82+ device_info : DeviceInfo ,
83+ ) {
9384 let ( mut ws_tx, mut ws_rx) = socket. split ( ) ;
9485 let mut set = JoinSet :: new ( ) ;
9586
87+ let request = AwaitRemoteMfaFinishRequest { token } ;
88+ let rx = match state. grpc_server . send (
89+ core_request:: Payload :: AwaitRemoteMfaFinish ( request) ,
90+ device_info,
91+ ) {
92+ Ok ( rx) => rx,
93+ Err ( err) => {
94+ error ! ( "Failed to send ClientRemoteMfaFinishRequest: {err:?}" ) ;
95+ return ;
96+ }
97+ } ;
98+
99+ // Response to ClientRemoteMfaFinishRequest comes once the user concludes MFA with mobile device.
100+ // This task then sends the preshared key to the WebSocket where desktop client awaits for it.
96101 set. spawn ( async move {
97- if let Ok ( msg) = rx. await {
98- let payload = json ! ( {
99- "type" : "mfa_success" ,
100- "preshared_key" : & msg,
101- } ) ;
102- if let Ok ( serialized) = serde_json:: to_string ( & payload) {
103- let message = Message :: Text ( serialized. into ( ) ) ;
104- if ws_tx. send ( message) . await . is_err ( ) {
105- error ! ( "Failed to send preshared key via ws" ) ;
102+ match rx. await {
103+ Ok ( Payload :: AwaitRemoteMfaFinish ( response) ) => {
104+ let ws_response = json ! ( {
105+ "type" : "mfa_success" ,
106+ "preshared_key" : & response. preshared_key,
107+ } ) ;
108+ if let Ok ( serialized) = serde_json:: to_string ( & ws_response) {
109+ let message = Message :: Text ( serialized. into ( ) ) ;
110+ if let Err ( err) = ws_tx. send ( message) . await {
111+ error ! ( "Failed to send preshared key via ws: {err:?}" ) ;
112+ }
106113 }
107- } else {
108- error ! ( "Failed to serialize remote mfa ws client response message" ) ;
109114 }
110- } else {
111- error ! ( "Failed to receive preshared key from receiver" ) ;
112- }
115+ Ok ( _) => {
116+ error ! ( "Received wrong response type, expected ClientRemoteMfaFinish" ) ;
117+ }
118+ Err ( err) => {
119+ error ! ( "Failed to receive preshared key from receiver: {err:?}" ) ;
120+ }
121+ } ;
122+
123+ // Close the websocket once we're done.
113124 let _ = ws_tx. close ( ) . await ;
114125 } ) ;
115126
127+ // Another task to monitor the websocket connection in case desktop client disconnects
128+ // or the connection errors-out.
116129 set. spawn ( async move {
117130 while let Some ( msg_result) = ws_rx. next ( ) . await {
118131 match msg_result {
@@ -129,10 +142,9 @@ async fn handle_remote_auth_socket(socket: WebSocket, state: AppState, token: St
129142 }
130143 } ) ;
131144
145+ // Wait for whichever task finishes first and kill the other one.
132146 let _ = set. join_next ( ) . await ;
133147 set. shutdown ( ) . await ;
134- // This will remove token, if it's still there.
135- state. remote_mfa_sessions . lock ( ) . await . remove ( & token) ;
136148}
137149
138150#[ instrument( level = "debug" , skip( state, req) ) ]
@@ -146,7 +158,7 @@ async fn start_client_mfa(
146158 core_request:: Payload :: ClientMfaStart ( req. clone ( ) ) ,
147159 device_info,
148160 ) ?;
149- let payload = get_core_response ( rx) . await ?;
161+ let payload = get_core_response ( rx, None ) . await ?;
150162
151163 if let core_response:: Payload :: ClientMfaStart ( response) = payload {
152164 info ! ( "Started desktop client authorization {req:?}" ) ;
@@ -167,7 +179,7 @@ async fn finish_client_mfa(
167179 let rx = state
168180 . grpc_server
169181 . send ( core_request:: Payload :: ClientMfaFinish ( req) , device_info) ?;
170- let payload = get_core_response ( rx) . await ?;
182+ let payload = get_core_response ( rx, None ) . await ?;
171183 if let core_response:: Payload :: ClientMfaFinish ( response) = payload {
172184 Ok ( Json ( response) )
173185 } else {
@@ -186,32 +198,10 @@ async fn finish_remote_mfa(
186198 let rx = state
187199 . grpc_server
188200 . send ( core_request:: Payload :: ClientMfaFinish ( req) , device_info) ?;
189- let payload = get_core_response ( rx) . await ?;
190- if let core_response:: Payload :: ClientMfaFinish ( response) = payload {
191- // Check if this needs to be forwarded.
192- if let Some ( token) = response. token {
193- let sender_option = {
194- let mut sessions = state. remote_mfa_sessions . lock ( ) . await ;
195- sessions. remove ( & token)
196- } ;
197- if let Some ( sender) = sender_option {
198- let _ = sender. send ( response. preshared_key ) ;
199- }
200- // If desktop stopped listening for the result, there will be no place to send the
201- // result.
202- else {
203- error ! ( "Remote MFA approve finished but session was not found." ) ;
204- return Err ( ApiError :: Unexpected ( String :: new ( ) ) ) ;
205- }
206-
207- info ! ( "Finished desktop client authorization via mobile device" ) ;
208- Ok ( Json ( json ! ( { } ) ) )
209- } else {
210- error ! ( "Remote MFA Unexpected core response, token was not returned" ) ;
211- Err ( ApiError :: Unexpected ( String :: new ( ) ) )
212- }
201+ if let core_response:: Payload :: ClientMfaFinish ( _response) = get_core_response ( rx, None ) . await ? {
202+ Ok ( Json ( json ! ( { } ) ) )
213203 } else {
214- error ! ( "Received invalid gRPC response type" ) ;
204+ error ! ( "Received invalid gRPC response type, expected ClientMfaFinish " ) ;
215205 Err ( ApiError :: InvalidResponseType )
216206 }
217207}
0 commit comments