Skip to content

Commit bdd55f7

Browse files
msirringhausiinuwa
authored andcommitted
Make hid_send/recv cancellable
1 parent c61492d commit bdd55f7

3 files changed

Lines changed: 174 additions & 12 deletions

File tree

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
use std::error::Error;
2+
3+
#[cfg(feature = "virtual-hid-device")]
4+
#[tokio::main]
5+
pub async fn main() -> Result<(), Box<dyn Error>> {
6+
// This example doesn't work for virtual devices, because
7+
// solo devices are not clone-able.
8+
Ok(())
9+
}
10+
11+
#[cfg(not(feature = "virtual-hid-device"))]
12+
#[tokio::main]
13+
pub async fn main() -> Result<(), Box<dyn Error>> {
14+
use std::collections::HashMap;
15+
use std::time::Duration;
16+
17+
use libwebauthn::transport::hid::channel::HidChannelHandle;
18+
use tracing_subscriber::{self, EnvFilter};
19+
20+
use libwebauthn::transport::hid::{list_devices, HidDevice};
21+
use libwebauthn::transport::Device;
22+
23+
fn setup_logging() {
24+
tracing_subscriber::fmt()
25+
.with_env_filter(EnvFilter::from_default_env())
26+
.without_time()
27+
.init();
28+
}
29+
setup_logging();
30+
31+
let devices = list_devices().await.unwrap();
32+
let mut expected_answers = devices.len();
33+
let (blinking_tx, mut blinking_rx) =
34+
tokio::sync::mpsc::channel::<Option<usize>>(expected_answers);
35+
let mut channel_map = HashMap::new();
36+
let (setup_tx, mut setup_rx) =
37+
tokio::sync::mpsc::channel::<(usize, HidDevice, HidChannelHandle)>(expected_answers);
38+
39+
println!("Found {expected_answers} devices. Select one by touching.");
40+
for (idx, mut device) in devices.into_iter().enumerate() {
41+
let stx = setup_tx.clone();
42+
let btx = blinking_tx.clone();
43+
44+
tokio::spawn(async move {
45+
let dev = device.clone();
46+
let (mut channel, _state_rx) = device.channel().await.unwrap();
47+
let handle = channel.get_handle();
48+
stx.send((idx, dev, handle)).await.unwrap();
49+
drop(stx);
50+
51+
println!("Blinking {idx}");
52+
let res = channel
53+
.blink_and_wait_for_user_presence(Duration::from_secs(300))
54+
.await;
55+
match res {
56+
Ok(true) => {
57+
println!("Touch from {idx}");
58+
btx.send(Some(idx)).await.unwrap();
59+
}
60+
Ok(false) | Err(_) => {
61+
btx.send(None).await.unwrap();
62+
}
63+
}
64+
});
65+
}
66+
drop(setup_tx);
67+
while let Some((idx, device, handle)) = setup_rx.recv().await {
68+
channel_map.insert(idx, (device, handle));
69+
}
70+
71+
drop(blinking_tx);
72+
let mut found_one = false;
73+
while let Some(msg) = blinking_rx.recv().await {
74+
expected_answers -= 1;
75+
match msg {
76+
Some(idx) => {
77+
println!("Received {idx}");
78+
for (key, (_device, handle)) in channel_map.iter() {
79+
if key == &idx {
80+
continue;
81+
}
82+
println!("Cancelling {key}");
83+
handle.cancel_ongoing_operation().await;
84+
}
85+
let (device, _handle) = &channel_map[&idx];
86+
println!("User chosen device: {device:?}");
87+
found_one = true;
88+
}
89+
None => {
90+
if expected_answers == 0 {
91+
if found_one {
92+
println!("All devices finished.");
93+
} else {
94+
println!("No device was chosen. All timed out.");
95+
}
96+
break;
97+
} else {
98+
continue;
99+
}
100+
}
101+
}
102+
}
103+
104+
Ok(())
105+
}

libwebauthn/src/transport/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ pub enum PlatformError {
1919
SyntaxError,
2020
#[error("cbor serialization error: {0}")]
2121
CborError(#[from] CborError),
22+
#[error("cancelled by user")]
23+
Cancelled,
2224
}
2325

2426
#[derive(thiserror::Error, Debug, PartialEq)]

libwebauthn/src/transport/hid/channel.rs

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
use std::convert::TryFrom;
22
use std::fmt::{Debug, Display, Formatter};
33
use std::io::{Cursor as IOCursor, Seek, SeekFrom};
4-
use std::ops::Deref;
4+
use std::ops::DerefMut;
55
use std::sync::{Arc, Mutex};
66
use std::time::Duration;
77

88
use async_trait::async_trait;
99
use byteorder::{BigEndian, ReadBytesExt};
1010
use hidapi::HidDevice as HidApiDevice;
1111
use rand::{thread_rng, Rng};
12-
use tokio::sync::mpsc;
12+
use tokio::sync::mpsc::error::TryRecvError;
13+
use tokio::sync::mpsc::{self, Receiver, Sender};
1314
use tokio::time::sleep;
1415
use tracing::{debug, info, instrument, trace, warn, Level};
1516

@@ -27,6 +28,7 @@ use crate::transport::error::{Error, TransportError};
2728
use crate::transport::hid::framing::{
2829
HidCommand, HidMessage, HidMessageParser, HidMessageParserState,
2930
};
31+
use crate::webauthn::PlatformError;
3032
use crate::UxUpdate;
3133

3234
use super::device::get_hidapi;
@@ -44,45 +46,66 @@ const REPORT_ID: u8 = 0x00;
4446
// by a CBOR command, so we want to ensure we wait some time after winking.
4547
const WINK_MIN_WAIT: Duration = Duration::from_secs(2);
4648

49+
pub type CancelHidOperation = ();
4750
enum OpenHidDevice {
48-
HidApiDevice(Arc<Mutex<HidApiDevice>>),
51+
HidApiDevice(Arc<Mutex<(HidApiDevice, mpsc::Receiver<CancelHidOperation>)>>),
4952
#[cfg(feature = "virtual-hid-device")]
5053
VirtualDevice,
5154
}
5255

56+
#[derive(Debug, Clone)]
57+
pub struct HidChannelHandle {
58+
tx: Sender<CancelHidOperation>,
59+
}
60+
61+
impl HidChannelHandle {
62+
pub async fn cancel_ongoing_operation(&self) {
63+
let _ = self.tx.send(()).await;
64+
}
65+
}
66+
5367
pub struct HidChannel<'d> {
5468
status: ChannelStatus,
5569
device: &'d HidDevice,
5670
open_device: OpenHidDevice,
5771
init: InitResponse,
5872
auth_token_data: Option<AuthTokenData>,
5973
tx: mpsc::Sender<UxUpdate>,
74+
handle: HidChannelHandle,
6075
}
6176

6277
impl<'d> HidChannel<'d> {
6378
pub async fn new(
6479
device: &'d HidDevice,
6580
tx: mpsc::Sender<UxUpdate>,
6681
) -> Result<HidChannel<'d>, Error> {
82+
let (handle_tx, handle_rx) = mpsc::channel(1);
83+
let handle = HidChannelHandle { tx: handle_tx };
84+
6785
let mut channel = Self {
6886
status: ChannelStatus::Ready,
6987
device,
7088
open_device: match device.backend {
7189
HidBackendDevice::HidApiDevice(_) => {
7290
let hidapi_device = Self::hid_open(device)?;
73-
OpenHidDevice::HidApiDevice(Arc::new(Mutex::new(hidapi_device)))
91+
OpenHidDevice::HidApiDevice(Arc::new(Mutex::new((hidapi_device, handle_rx))))
7492
}
7593
#[cfg(feature = "virtual-hid-device")]
7694
HidBackendDevice::VirtualDevice(_) => OpenHidDevice::VirtualDevice,
7795
},
7896
init: InitResponse::default(),
7997
auth_token_data: None,
8098
tx,
99+
handle,
81100
};
82101
channel.init = channel.init(INIT_TIMEOUT).await?;
83102
Ok(channel)
84103
}
85104

105+
pub fn get_handle(&self) -> HidChannelHandle {
106+
self.handle.clone()
107+
}
108+
86109
#[instrument(skip_all)]
87110
pub async fn wink(&mut self, timeout: Duration) -> Result<bool, Error> {
88111
if !self.init.caps.contains(Caps::WINK) {
@@ -246,22 +269,40 @@ impl<'d> HidChannel<'d> {
246269
pub async fn hid_send(&self, msg: &HidMessage) -> Result<(), Error> {
247270
match &self.open_device {
248271
OpenHidDevice::HidApiDevice(hidapi_device) => {
249-
let Ok(guard) = hidapi_device.lock() else {
272+
let Ok(mut guard) = hidapi_device.lock() else {
250273
warn!("Poisoned lock on HID API device");
251274
return Err(Error::Transport(TransportError::ConnectionLost));
252275
};
253-
Self::hid_send_hidapi(guard.deref(), msg)
276+
let (device, cancel_rx) = guard.deref_mut();
277+
let response = Self::hid_send_hidapi(device, cancel_rx, msg);
278+
if matches!(response, Err(Error::Platform(PlatformError::Cancelled))) {
279+
// Using hid_send_hidapi directly, instead of hid_cancel, to avoid recursion
280+
let _ = Self::hid_send_hidapi(
281+
device,
282+
cancel_rx,
283+
&HidMessage::new(self.init.cid, HidCommand::Cancel, &[]),
284+
);
285+
}
286+
response
254287
}
255288
#[cfg(feature = "virtual-hid-device")]
256289
OpenHidDevice::VirtualDevice => Self::hid_send_virtual(msg).await,
257290
}
258291
}
259292

260-
fn hid_send_hidapi(device: &hidapi::HidDevice, msg: &HidMessage) -> Result<(), Error> {
293+
fn hid_send_hidapi(
294+
device: &hidapi::HidDevice,
295+
cancel_rx: &mut Receiver<CancelHidOperation>,
296+
msg: &HidMessage,
297+
) -> Result<(), Error> {
261298
let packets = msg
262299
.packets(PACKET_SIZE)
263300
.or(Err(Error::Transport(TransportError::InvalidFraming)))?;
264301
for (i, packet) in packets.iter().enumerate() {
302+
if !matches!(cancel_rx.try_recv(), Err(TryRecvError::Empty)) {
303+
return Err(Error::Platform(PlatformError::Cancelled));
304+
}
305+
265306
let mut report: Vec<u8> = vec![REPORT_ID];
266307
report.extend(packet);
267308
report.extend(vec![0; PACKET_SIZE - packet.len()]);
@@ -319,13 +360,15 @@ impl<'d> HidChannel<'d> {
319360
// Note that we're just using spawn_blocking() on hid_recv(), not on hid_send(),
320361
// since implementing this on hid_send and would cause unnecessary copies/locking.
321362
tokio::task::spawn_blocking(move || {
322-
let Ok(guard) = device.lock() else {
363+
let Ok(mut guard) = device.lock() else {
323364
warn!("Poisoned lock on HID API device");
324365
return Err(Error::Transport(TransportError::ConnectionLost));
325366
};
326-
let device = guard.deref();
327-
Self::hid_recv_hidapi(device, timeout)
328-
}).await.unwrap()
367+
let (device, cancel_rx) = guard.deref_mut();
368+
Self::hid_recv_hidapi(device, cancel_rx, timeout)
369+
})
370+
.await
371+
.expect("HID read not to panic.")
329372
}
330373
#[cfg(feature = "virtual-hid-device")]
331374
OpenHidDevice::VirtualDevice => Self::hid_recv_virtual(timeout).await,
@@ -339,14 +382,26 @@ impl<'d> HidChannel<'d> {
339382
debug!("Ignoring HID keep-alive");
340383
continue;
341384
}
385+
Err(Error::Platform(PlatformError::Cancelled)) => {
386+
let _ = self.hid_cancel().await;
387+
break response;
388+
}
342389
_ => break response,
343390
}
344391
}
345392
}
346393

347-
fn hid_recv_hidapi(device: &hidapi::HidDevice, timeout: Duration) -> Result<HidMessage, Error> {
394+
fn hid_recv_hidapi(
395+
device: &hidapi::HidDevice,
396+
cancel_rx: &mut Receiver<CancelHidOperation>,
397+
timeout: Duration,
398+
) -> Result<HidMessage, Error> {
348399
let mut parser = HidMessageParser::new();
349400
loop {
401+
if !matches!(cancel_rx.try_recv(), Err(TryRecvError::Empty)) {
402+
return Err(Error::Platform(PlatformError::Cancelled));
403+
}
404+
350405
let mut report = [0; PACKET_SIZE];
351406
device
352407
.read_timeout(&mut report, timeout.as_millis() as i32)

0 commit comments

Comments
 (0)