Skip to content

Commit 3d7663e

Browse files
Implement remote MFA with new, separate RPC message (#238)
* Refactor remote MFA flow to use ClientRemoteMfaFinish RPC * handle remote mfa core comms in ws thread * longer timeout for remote mfa * Use RwLocks instead of Mutexes * remove unused proto * rename rpc method * update protos
1 parent eeff92e commit 3d7663e

File tree

12 files changed

+97
-105
lines changed

12 files changed

+97
-105
lines changed

proto

src/enterprise/handlers/desktop_client_mfa.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ pub(super) async fn mfa_auth_callback(
9090
device_info,
9191
)?;
9292

93-
let payload = get_core_response(rx).await?;
93+
let payload = get_core_response(rx, None).await?;
9494

9595
if let core_response::Payload::Empty(()) = payload {
9696
info!("MFA authentication callback completed successfully");

src/enterprise/handlers/openid_login.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ async fn auth_info(
7676
let rx = state
7777
.grpc_server
7878
.send(core_request::Payload::AuthInfo(request), device_info)?;
79-
let payload = get_core_response(rx).await?;
79+
let payload = get_core_response(rx, None).await?;
8080
if let core_response::Payload::AuthInfo(response) = payload {
8181
debug!("Received auth info response");
8282

@@ -164,7 +164,7 @@ async fn auth_callback(
164164
let rx = state
165165
.grpc_server
166166
.send(core_request::Payload::AuthCallback(request), device_info)?;
167-
let payload = get_core_response(rx).await?;
167+
let payload = get_core_response(rx, None).await?;
168168

169169
if let core_response::Payload::AuthCallback(AuthCallbackResponse { url, token }) = payload {
170170
debug!("Received auth callback response {url:?} {token:?}");

src/grpc.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ pub struct Configuration {
4040

4141
pub(crate) struct ProxyServer {
4242
current_id: Arc<AtomicU64>,
43-
clients: Arc<Mutex<ClientMap>>,
44-
results: Arc<Mutex<HashMap<u64, oneshot::Sender<core_response::Payload>>>>,
43+
clients: Arc<RwLock<ClientMap>>,
44+
results: Arc<RwLock<HashMap<u64, oneshot::Sender<core_response::Payload>>>>,
4545
pub(crate) connected: Arc<AtomicBool>,
4646
pub(crate) core_version: Arc<Mutex<Option<Version>>>,
4747
config: Arc<Mutex<Option<Configuration>>>,
@@ -55,8 +55,8 @@ impl ProxyServer {
5555
Self {
5656
cookie_key,
5757
current_id: Arc::new(AtomicU64::new(1)),
58-
clients: Arc::new(Mutex::new(HashMap::new())),
59-
results: Arc::new(Mutex::new(HashMap::new())),
58+
clients: Arc::new(RwLock::new(HashMap::new())),
59+
results: Arc::new(RwLock::new(HashMap::new())),
6060
connected: Arc::new(AtomicBool::new(false)),
6161
core_version: Arc::new(Mutex::new(None)),
6262
config: Arc::new(Mutex::new(None)),
@@ -126,7 +126,7 @@ impl ProxyServer {
126126
) -> Result<oneshot::Receiver<core_response::Payload>, ApiError> {
127127
if let Some(client_tx) = self
128128
.clients
129-
.lock()
129+
.read()
130130
.expect("Failed to acquire lock on clients hashmap when sending message to core")
131131
.values()
132132
.next()
@@ -143,7 +143,7 @@ impl ProxyServer {
143143
}
144144
let (tx, rx) = oneshot::channel();
145145
self.results
146-
.lock()
146+
.write()
147147
.expect("Failed to acquire lock on results hashmap when sending CoreRequest")
148148
.insert(id, tx);
149149
self.connected.store(true, Ordering::Relaxed);
@@ -214,7 +214,7 @@ impl proxy_server::Proxy for ProxyServer {
214214
info!("Defguard Core gRPC client connected from: {address}");
215215
let (tx, rx) = mpsc::unbounded_channel();
216216
self.clients
217-
.lock()
217+
.write()
218218
.expect(
219219
"Failed to acquire lock on clients hashmap when registering new core connection",
220220
)
@@ -241,7 +241,7 @@ impl proxy_server::Proxy for ProxyServer {
241241
*cookie_key.write().unwrap() = Some(key);
242242
},
243243
_ => {
244-
let maybe_rx = results.lock().expect("Failed to acquire lock on results hashmap when processing response").remove(&response.id);
244+
let maybe_rx = results.write().expect("Failed to acquire lock on results hashmap when processing response").remove(&response.id);
245245
if let Some(rx) = maybe_rx {
246246
if let Err(err) = rx.send(payload) {
247247
error!("Failed to send message to rx {:?}", err.type_id());
@@ -265,7 +265,7 @@ impl proxy_server::Proxy for ProxyServer {
265265
}
266266
info!("Defguard core client disconnected: {address}");
267267
connected.store(false, Ordering::Relaxed);
268-
clients.lock().expect("Failed to acquire lock on clients hashmap when removing disconnected client").remove(&address);
268+
clients.write().expect("Failed to acquire lock on clients hashmap when removing disconnected client").remove(&address);
269269
}
270270
.instrument(tracing::Span::current()),
271271
);

src/handlers/desktop_client_mfa.rs

Lines changed: 62 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::collections::hash_map::Entry;
1+
use std::time::Duration;
22

33
use axum::{
44
extract::{
@@ -12,18 +12,23 @@ use axum::{
1212
use futures_util::{sink::SinkExt, stream::StreamExt};
1313
use serde::Deserialize;
1414
use serde_json::json;
15-
use tokio::{sync::oneshot, task::JoinSet};
15+
use tokio::task::JoinSet;
1616

1717
use crate::{
1818
error::ApiError,
1919
handlers::get_core_response,
2020
http::AppState,
2121
proto::{
22-
core_request, core_response, ClientMfaFinishRequest, ClientMfaFinishResponse,
22+
core_request,
23+
core_response::{self, Payload},
24+
AwaitRemoteMfaFinishRequest, ClientMfaFinishRequest, ClientMfaFinishResponse,
2325
ClientMfaStartRequest, ClientMfaStartResponse, DeviceInfo,
2426
},
2527
};
2628

29+
// How much time the user has to approve remote MFA with mobile device
30+
const REMOTE_AUTH_TIMEOUT: Duration = Duration::from_secs(60);
31+
2732
pub(crate) fn router() -> Router<AppState> {
2833
Router::new()
2934
.route("/start", post(start_client_mfa))
@@ -53,66 +58,74 @@ async fn await_remote_auth(
5358
token: token.clone(),
5459
},
5560
),
56-
device_info,
61+
device_info.clone(),
5762
)?;
58-
let payload = get_core_response(rx).await?;
63+
let payload = get_core_response(rx, Some(REMOTE_AUTH_TIMEOUT)).await?;
5964
if let core_response::Payload::ClientMfaTokenValidation(response) = payload {
6065
if !response.token_valid {
6166
return Err(ApiError::Unauthorized(String::new()));
6267
}
63-
// check if its already in the map
64-
let contains_key = {
65-
let sessions = state.remote_mfa_sessions.lock().await;
66-
sessions.contains_key(&token)
67-
};
68-
if contains_key {
69-
return Err(ApiError::Unauthorized(String::new()));
70-
}
71-
Ok(ws.on_upgrade(move |socket| handle_remote_auth_socket(socket, state.clone(), token)))
68+
69+
Ok(ws.on_upgrade(move |socket| {
70+
handle_remote_auth_socket(socket, state.clone(), token, device_info)
71+
}))
7272
} else {
7373
Err(ApiError::InvalidResponseType)
7474
}
7575
}
7676

7777
/// Handle axum web socket upgrade for `await_remote_auth`.
78-
async fn handle_remote_auth_socket(socket: WebSocket, state: AppState, token: String) {
79-
let (tx, rx) = oneshot::channel();
80-
81-
{
82-
let mut sessions = state.remote_mfa_sessions.lock().await;
83-
match sessions.entry(token.clone()) {
84-
Entry::Occupied(_) => {
85-
return;
86-
}
87-
Entry::Vacant(v) => {
88-
v.insert(tx);
89-
}
90-
}
91-
}
92-
78+
async fn handle_remote_auth_socket(
79+
socket: WebSocket,
80+
state: AppState,
81+
token: String,
82+
device_info: DeviceInfo,
83+
) {
9384
let (mut ws_tx, mut ws_rx) = socket.split();
9485
let mut set = JoinSet::new();
9586

87+
let request = AwaitRemoteMfaFinishRequest { token };
88+
let rx = match state.grpc_server.send(
89+
core_request::Payload::AwaitRemoteMfaFinish(request),
90+
device_info,
91+
) {
92+
Ok(rx) => rx,
93+
Err(err) => {
94+
error!("Failed to send ClientRemoteMfaFinishRequest: {err:?}");
95+
return;
96+
}
97+
};
98+
99+
// Response to ClientRemoteMfaFinishRequest comes once the user concludes MFA with mobile device.
100+
// This task then sends the preshared key to the WebSocket where desktop client awaits for it.
96101
set.spawn(async move {
97-
if let Ok(msg) = rx.await {
98-
let payload = json!({
99-
"type": "mfa_success",
100-
"preshared_key": &msg,
101-
});
102-
if let Ok(serialized) = serde_json::to_string(&payload) {
103-
let message = Message::Text(serialized.into());
104-
if ws_tx.send(message).await.is_err() {
105-
error!("Failed to send preshared key via ws");
102+
match rx.await {
103+
Ok(Payload::AwaitRemoteMfaFinish(response)) => {
104+
let ws_response = json!({
105+
"type": "mfa_success",
106+
"preshared_key": &response.preshared_key,
107+
});
108+
if let Ok(serialized) = serde_json::to_string(&ws_response) {
109+
let message = Message::Text(serialized.into());
110+
if let Err(err) = ws_tx.send(message).await {
111+
error!("Failed to send preshared key via ws: {err:?}");
112+
}
106113
}
107-
} else {
108-
error!("Failed to serialize remote mfa ws client response message");
109114
}
110-
} else {
111-
error!("Failed to receive preshared key from receiver");
112-
}
115+
Ok(_) => {
116+
error!("Received wrong response type, expected ClientRemoteMfaFinish");
117+
}
118+
Err(err) => {
119+
error!("Failed to receive preshared key from receiver: {err:?}");
120+
}
121+
};
122+
123+
// Close the websocket once we're done.
113124
let _ = ws_tx.close().await;
114125
});
115126

127+
// Another task to monitor the websocket connection in case desktop client disconnects
128+
// or the connection errors-out.
116129
set.spawn(async move {
117130
while let Some(msg_result) = ws_rx.next().await {
118131
match msg_result {
@@ -129,10 +142,9 @@ async fn handle_remote_auth_socket(socket: WebSocket, state: AppState, token: St
129142
}
130143
});
131144

145+
// Wait for whichever task finishes first and kill the other one.
132146
let _ = set.join_next().await;
133147
set.shutdown().await;
134-
// This will remove token, if it's still there.
135-
state.remote_mfa_sessions.lock().await.remove(&token);
136148
}
137149

138150
#[instrument(level = "debug", skip(state, req))]
@@ -146,7 +158,7 @@ async fn start_client_mfa(
146158
core_request::Payload::ClientMfaStart(req.clone()),
147159
device_info,
148160
)?;
149-
let payload = get_core_response(rx).await?;
161+
let payload = get_core_response(rx, None).await?;
150162

151163
if let core_response::Payload::ClientMfaStart(response) = payload {
152164
info!("Started desktop client authorization {req:?}");
@@ -167,7 +179,7 @@ async fn finish_client_mfa(
167179
let rx = state
168180
.grpc_server
169181
.send(core_request::Payload::ClientMfaFinish(req), device_info)?;
170-
let payload = get_core_response(rx).await?;
182+
let payload = get_core_response(rx, None).await?;
171183
if let core_response::Payload::ClientMfaFinish(response) = payload {
172184
Ok(Json(response))
173185
} else {
@@ -186,32 +198,10 @@ async fn finish_remote_mfa(
186198
let rx = state
187199
.grpc_server
188200
.send(core_request::Payload::ClientMfaFinish(req), device_info)?;
189-
let payload = get_core_response(rx).await?;
190-
if let core_response::Payload::ClientMfaFinish(response) = payload {
191-
// Check if this needs to be forwarded.
192-
if let Some(token) = response.token {
193-
let sender_option = {
194-
let mut sessions = state.remote_mfa_sessions.lock().await;
195-
sessions.remove(&token)
196-
};
197-
if let Some(sender) = sender_option {
198-
let _ = sender.send(response.preshared_key);
199-
}
200-
// If desktop stopped listening for the result, there will be no place to send the
201-
// result.
202-
else {
203-
error!("Remote MFA approve finished but session was not found.");
204-
return Err(ApiError::Unexpected(String::new()));
205-
}
206-
207-
info!("Finished desktop client authorization via mobile device");
208-
Ok(Json(json!({})))
209-
} else {
210-
error!("Remote MFA Unexpected core response, token was not returned");
211-
Err(ApiError::Unexpected(String::new()))
212-
}
201+
if let core_response::Payload::ClientMfaFinish(_response) = get_core_response(rx, None).await? {
202+
Ok(Json(json!({})))
213203
} else {
214-
error!("Received invalid gRPC response type");
204+
error!("Received invalid gRPC response type, expected ClientMfaFinish");
215205
Err(ApiError::InvalidResponseType)
216206
}
217207
}

src/handlers/enrollment.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ async fn start_enrollment_process(
4545
let rx = state
4646
.grpc_server
4747
.send(core_request::Payload::EnrollmentStart(req), device_info)?;
48-
let payload = get_core_response(rx).await?;
48+
let payload = get_core_response(rx, None).await?;
4949
debug!("Receving payload from the core service. Try to set private cookie for starting enrollment process.");
5050
if let core_response::Payload::EnrollmentStart(response) = payload {
5151
info!(
@@ -83,7 +83,7 @@ async fn activate_user(
8383
let rx = state
8484
.grpc_server
8585
.send(core_request::Payload::ActivateUser(req), device_info)?;
86-
let payload = get_core_response(rx).await?;
86+
let payload = get_core_response(rx, None).await?;
8787
debug!("Receiving payload from the core service. Trying to remove private cookie...");
8888
if let core_response::Payload::Empty(()) = payload {
8989
info!("Activated user - phone number {phone:?}");
@@ -116,7 +116,7 @@ async fn create_device(
116116
let rx = state
117117
.grpc_server
118118
.send(core_request::Payload::NewDevice(req), device_info)?;
119-
let payload = get_core_response(rx).await?;
119+
let payload = get_core_response(rx, None).await?;
120120
if let core_response::Payload::DeviceConfig(response) = payload {
121121
info!("Added new device {name} {pubkey}");
122122
Ok(Json(response))
@@ -144,7 +144,7 @@ async fn get_network_info(
144144
let rx = state
145145
.grpc_server
146146
.send(core_request::Payload::ExistingDevice(req), device_info)?;
147-
let payload = get_core_response(rx).await?;
147+
let payload = get_core_response(rx, None).await?;
148148
if let core_response::Payload::DeviceConfig(response) = payload {
149149
info!("Got network info for device {pubkey}");
150150
Ok(Json(response))

src/handlers/mobile_client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ pub(crate) async fn register_mobile_auth(
5353
core_request::Payload::RegisterMobileAuth(send_data),
5454
device_info,
5555
)?;
56-
let payload = get_core_response(rx).await?;
56+
let payload = get_core_response(rx, None).await?;
5757
if let core_response::Payload::Empty(()) = payload {
5858
info!("Registered mobile device for auth");
5959
Ok(())

src/handlers/mod.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::time::Duration;
33
use axum::{extract::FromRequestParts, http::request::Parts};
44
use axum_client_ip::{InsecureClientIp, LeftmostXForwardedFor};
55
use axum_extra::{headers::UserAgent, TypedHeader};
6-
use tokio::{sync::oneshot::Receiver, time::timeout};
6+
use tokio::{sync::oneshot::Receiver, time};
77
use tonic::Code;
88

99
use super::proto::DeviceInfo;
@@ -69,9 +69,12 @@ where
6969
/// Helper which awaits core response
7070
///
7171
/// Waits for core response with a given timeout and returns the response payload.
72-
pub(crate) async fn get_core_response(rx: Receiver<Payload>) -> Result<Payload, ApiError> {
72+
pub(crate) async fn get_core_response(
73+
rx: Receiver<Payload>,
74+
timeout: Option<Duration>,
75+
) -> Result<Payload, ApiError> {
7376
debug!("Fetching core response.");
74-
if let Ok(core_response) = timeout(CORE_RESPONSE_TIMEOUT, rx).await {
77+
if let Ok(core_response) = time::timeout(timeout.unwrap_or(CORE_RESPONSE_TIMEOUT), rx).await {
7578
debug!("Got gRPC response from Defguard Core");
7679
if let Ok(Payload::CoreError(core_error)) = core_response {
7780
if core_error.status_code == Code::FailedPrecondition as i32
@@ -92,7 +95,10 @@ pub(crate) async fn get_core_response(rx: Receiver<Payload>) -> Result<Payload,
9295
core_response
9396
.map_err(|err| ApiError::Unexpected(format!("Failed to receive core response: {err}")))
9497
} else {
95-
error!("Did not receive response from Core within {CORE_RESPONSE_TIMEOUT:?}");
98+
error!(
99+
"Did not receive response from Core within {:?}",
100+
timeout.unwrap_or(CORE_RESPONSE_TIMEOUT)
101+
);
96102
Err(ApiError::CoreTimeout)
97103
}
98104
}

0 commit comments

Comments
 (0)