11use std:: convert:: TryFrom ;
22use std:: fmt:: { Debug , Display , Formatter } ;
33use std:: io:: { Cursor as IOCursor , Seek , SeekFrom } ;
4- use std:: ops:: Deref ;
4+ use std:: ops:: DerefMut ;
55use std:: sync:: { Arc , Mutex } ;
66use std:: time:: Duration ;
77
88use async_trait:: async_trait;
99use byteorder:: { BigEndian , ReadBytesExt } ;
1010use hidapi:: HidDevice as HidApiDevice ;
1111use rand:: { thread_rng, Rng } ;
12- use tokio:: sync:: mpsc;
12+ use tokio:: sync:: mpsc:: error:: TryRecvError ;
13+ use tokio:: sync:: mpsc:: { self , Receiver , Sender } ;
1314use tokio:: time:: sleep;
1415use tracing:: { debug, info, instrument, trace, warn, Level } ;
1516
@@ -27,6 +28,7 @@ use crate::transport::error::{Error, TransportError};
2728use crate :: transport:: hid:: framing:: {
2829 HidCommand , HidMessage , HidMessageParser , HidMessageParserState ,
2930} ;
31+ use crate :: webauthn:: PlatformError ;
3032use crate :: UxUpdate ;
3133
3234use super :: device:: get_hidapi;
@@ -44,45 +46,66 @@ const REPORT_ID: u8 = 0x00;
4446// by a CBOR command, so we want to ensure we wait some time after winking.
4547const WINK_MIN_WAIT : Duration = Duration :: from_secs ( 2 ) ;
4648
49+ pub type CancelHidOperation = ( ) ;
4750enum OpenHidDevice {
48- HidApiDevice ( Arc < Mutex < HidApiDevice > > ) ,
51+ HidApiDevice ( Arc < Mutex < ( HidApiDevice , mpsc :: Receiver < CancelHidOperation > ) > > ) ,
4952 #[ cfg( feature = "virtual-hid-device" ) ]
5053 VirtualDevice ,
5154}
5255
56+ #[ derive( Debug , Clone ) ]
57+ pub struct HidChannelHandle {
58+ tx : Sender < CancelHidOperation > ,
59+ }
60+
61+ impl HidChannelHandle {
62+ pub async fn cancel_ongoing_operation ( & self ) {
63+ let _ = self . tx . send ( ( ) ) . await ;
64+ }
65+ }
66+
5367pub struct HidChannel < ' d > {
5468 status : ChannelStatus ,
5569 device : & ' d HidDevice ,
5670 open_device : OpenHidDevice ,
5771 init : InitResponse ,
5872 auth_token_data : Option < AuthTokenData > ,
5973 tx : mpsc:: Sender < UxUpdate > ,
74+ handle : HidChannelHandle ,
6075}
6176
6277impl < ' d > HidChannel < ' d > {
6378 pub async fn new (
6479 device : & ' d HidDevice ,
6580 tx : mpsc:: Sender < UxUpdate > ,
6681 ) -> Result < HidChannel < ' d > , Error > {
82+ let ( handle_tx, handle_rx) = mpsc:: channel ( 1 ) ;
83+ let handle = HidChannelHandle { tx : handle_tx } ;
84+
6785 let mut channel = Self {
6886 status : ChannelStatus :: Ready ,
6987 device,
7088 open_device : match device. backend {
7189 HidBackendDevice :: HidApiDevice ( _) => {
7290 let hidapi_device = Self :: hid_open ( device) ?;
73- OpenHidDevice :: HidApiDevice ( Arc :: new ( Mutex :: new ( hidapi_device) ) )
91+ OpenHidDevice :: HidApiDevice ( Arc :: new ( Mutex :: new ( ( hidapi_device, handle_rx ) ) ) )
7492 }
7593 #[ cfg( feature = "virtual-hid-device" ) ]
7694 HidBackendDevice :: VirtualDevice ( _) => OpenHidDevice :: VirtualDevice ,
7795 } ,
7896 init : InitResponse :: default ( ) ,
7997 auth_token_data : None ,
8098 tx,
99+ handle,
81100 } ;
82101 channel. init = channel. init ( INIT_TIMEOUT ) . await ?;
83102 Ok ( channel)
84103 }
85104
105+ pub fn get_handle ( & self ) -> HidChannelHandle {
106+ self . handle . clone ( )
107+ }
108+
86109 #[ instrument( skip_all) ]
87110 pub async fn wink ( & mut self , timeout : Duration ) -> Result < bool , Error > {
88111 if !self . init . caps . contains ( Caps :: WINK ) {
@@ -246,22 +269,40 @@ impl<'d> HidChannel<'d> {
246269 pub async fn hid_send ( & self , msg : & HidMessage ) -> Result < ( ) , Error > {
247270 match & self . open_device {
248271 OpenHidDevice :: HidApiDevice ( hidapi_device) => {
249- let Ok ( guard) = hidapi_device. lock ( ) else {
272+ let Ok ( mut guard) = hidapi_device. lock ( ) else {
250273 warn ! ( "Poisoned lock on HID API device" ) ;
251274 return Err ( Error :: Transport ( TransportError :: ConnectionLost ) ) ;
252275 } ;
253- Self :: hid_send_hidapi ( guard. deref ( ) , msg)
276+ let ( device, cancel_rx) = guard. deref_mut ( ) ;
277+ let response = Self :: hid_send_hidapi ( device, cancel_rx, msg) ;
278+ if matches ! ( response, Err ( Error :: Platform ( PlatformError :: Cancelled ) ) ) {
279+ // Using hid_send_hidapi directly, instead of hid_cancel, to avoid recursion
280+ let _ = Self :: hid_send_hidapi (
281+ device,
282+ cancel_rx,
283+ & HidMessage :: new ( self . init . cid , HidCommand :: Cancel , & [ ] ) ,
284+ ) ;
285+ }
286+ response
254287 }
255288 #[ cfg( feature = "virtual-hid-device" ) ]
256289 OpenHidDevice :: VirtualDevice => Self :: hid_send_virtual ( msg) . await ,
257290 }
258291 }
259292
260- fn hid_send_hidapi ( device : & hidapi:: HidDevice , msg : & HidMessage ) -> Result < ( ) , Error > {
293+ fn hid_send_hidapi (
294+ device : & hidapi:: HidDevice ,
295+ cancel_rx : & mut Receiver < CancelHidOperation > ,
296+ msg : & HidMessage ,
297+ ) -> Result < ( ) , Error > {
261298 let packets = msg
262299 . packets ( PACKET_SIZE )
263300 . or ( Err ( Error :: Transport ( TransportError :: InvalidFraming ) ) ) ?;
264301 for ( i, packet) in packets. iter ( ) . enumerate ( ) {
302+ if !matches ! ( cancel_rx. try_recv( ) , Err ( TryRecvError :: Empty ) ) {
303+ return Err ( Error :: Platform ( PlatformError :: Cancelled ) ) ;
304+ }
305+
265306 let mut report: Vec < u8 > = vec ! [ REPORT_ID ] ;
266307 report. extend ( packet) ;
267308 report. extend ( vec ! [ 0 ; PACKET_SIZE - packet. len( ) ] ) ;
@@ -319,13 +360,15 @@ impl<'d> HidChannel<'d> {
319360 // Note that we're just using spawn_blocking() on hid_recv(), not on hid_send(),
320361 // since implementing this on hid_send and would cause unnecessary copies/locking.
321362 tokio:: task:: spawn_blocking ( move || {
322- let Ok ( guard) = device. lock ( ) else {
363+ let Ok ( mut guard) = device. lock ( ) else {
323364 warn ! ( "Poisoned lock on HID API device" ) ;
324365 return Err ( Error :: Transport ( TransportError :: ConnectionLost ) ) ;
325366 } ;
326- let device = guard. deref ( ) ;
327- Self :: hid_recv_hidapi ( device, timeout)
328- } ) . await . unwrap ( )
367+ let ( device, cancel_rx) = guard. deref_mut ( ) ;
368+ Self :: hid_recv_hidapi ( device, cancel_rx, timeout)
369+ } )
370+ . await
371+ . expect ( "HID read not to panic." )
329372 }
330373 #[ cfg( feature = "virtual-hid-device" ) ]
331374 OpenHidDevice :: VirtualDevice => Self :: hid_recv_virtual ( timeout) . await ,
@@ -339,14 +382,26 @@ impl<'d> HidChannel<'d> {
339382 debug ! ( "Ignoring HID keep-alive" ) ;
340383 continue ;
341384 }
385+ Err ( Error :: Platform ( PlatformError :: Cancelled ) ) => {
386+ let _ = self . hid_cancel ( ) . await ;
387+ break response;
388+ }
342389 _ => break response,
343390 }
344391 }
345392 }
346393
347- fn hid_recv_hidapi ( device : & hidapi:: HidDevice , timeout : Duration ) -> Result < HidMessage , Error > {
394+ fn hid_recv_hidapi (
395+ device : & hidapi:: HidDevice ,
396+ cancel_rx : & mut Receiver < CancelHidOperation > ,
397+ timeout : Duration ,
398+ ) -> Result < HidMessage , Error > {
348399 let mut parser = HidMessageParser :: new ( ) ;
349400 loop {
401+ if !matches ! ( cancel_rx. try_recv( ) , Err ( TryRecvError :: Empty ) ) {
402+ return Err ( Error :: Platform ( PlatformError :: Cancelled ) ) ;
403+ }
404+
350405 let mut report = [ 0 ; PACKET_SIZE ] ;
351406 device
352407 . read_timeout ( & mut report, timeout. as_millis ( ) as i32 )
0 commit comments