@@ -7,7 +7,7 @@ use std::{
77use libwebauthn:: {
88 self ,
99 ops:: webauthn:: { GetAssertionResponse , MakeCredentialResponse } ,
10- transport:: Device as _,
10+ transport:: { hid :: HidDevice , Device as _} ,
1111 webauthn:: { Error as WebAuthnError , WebAuthn } ,
1212 UxUpdate ,
1313} ;
@@ -35,6 +35,8 @@ pub struct CredentialService {
3535 usb_state : AsyncArc < AsyncMutex < UsbState > > ,
3636 usb_uv_handler : UsbUvHandler ,
3737
38+ internal_hid_devices : Vec < HidDevice > ,
39+ chosen_hid_device : Option < HidDevice > ,
3840 internal_device_credentials : Vec < CredentialMetadata > ,
3941 internal_device_state : InternalDeviceState ,
4042 internal_pin_attempts_left : u32 ,
@@ -81,13 +83,15 @@ impl CredentialService {
8183 usb_state : usb_state. clone ( ) ,
8284 usb_uv_handler : UsbUvHandler :: new ( ) ,
8385
86+ internal_hid_devices : Vec :: new ( ) ,
8487 internal_device_credentials,
8588 internal_device_state : InternalDeviceState :: Idle ,
8689 internal_pin_attempts_left : 5 ,
8790 internal_pin_unlock_time : None ,
8891
8992 cred_request,
9093 cred_response,
94+ chosen_hid_device : None ,
9195 }
9296 }
9397
@@ -97,32 +101,73 @@ impl CredentialService {
97101
98102 pub ( crate ) async fn poll_device_discovery_usb ( & mut self ) -> Result < UsbState , String > {
99103 debug ! ( "polling for USB status" ) ;
100- let prev_usb_state = self . usb_state . lock ( ) . await . clone ( ) ;
104+ let prev_usb_state = * self . usb_state . lock ( ) . await ;
101105 let next_usb_state = match prev_usb_state {
102106 UsbState :: Idle | UsbState :: Waiting => {
103- let devices = libwebauthn:: transport:: hid:: list_devices ( ) . await . unwrap ( ) ;
104- if devices. is_empty ( ) {
107+ self . internal_hid_devices =
108+ libwebauthn:: transport:: hid:: list_devices ( ) . await . unwrap ( ) ;
109+ if self . internal_hid_devices . is_empty ( ) {
105110 let state = UsbState :: Waiting ;
106111 * self . usb_state . lock ( ) . await = state;
107112 return Ok ( state) ;
108- }
109- if devices. is_empty ( ) {
110- Ok ( UsbState :: Waiting )
111- } else {
113+ } else if self . internal_hid_devices . len ( ) == 1 {
114+ self . chosen_hid_device = Some ( self . internal_hid_devices . swap_remove ( 0 ) ) ;
112115 Ok ( UsbState :: Connected )
116+ } else {
117+ Ok ( UsbState :: SelectingDevice )
113118 }
114119 }
120+ UsbState :: SelectingDevice => {
121+ let ( blinking_tx, mut blinking_rx) = tokio:: sync:: mpsc:: channel :: < Option < HidDevice > > (
122+ self . internal_hid_devices . len ( ) ,
123+ ) ;
124+ let mut expected_answers = self . internal_hid_devices . len ( ) ;
125+ for mut device in self . internal_hid_devices . drain ( ..) {
126+ let tx = blinking_tx. clone ( ) ;
127+ tokio ( ) . spawn ( async move {
128+ let ( mut channel, _state_rx) = device. channel ( ) . await . unwrap ( ) ;
129+ let res = channel
130+ . blink_and_wait_for_user_presence ( Duration :: from_secs ( 300 ) )
131+ . await ;
132+ drop ( channel) ;
133+ match res {
134+ Ok ( true ) => {
135+ let _ = tx. send ( Some ( device) ) . await ;
136+ }
137+ Ok ( false ) | Err ( _) => {
138+ let _ = tx. send ( None ) . await ;
139+ }
140+ }
141+ } ) ;
142+ }
143+ let mut state = UsbState :: Idle ;
144+ while let Some ( msg) = blinking_rx. recv ( ) . await {
145+ expected_answers -= 1 ;
146+ match msg {
147+ Some ( device) => {
148+ self . chosen_hid_device = Some ( device) ;
149+ state = UsbState :: Connected ;
150+ break ;
151+ }
152+ None => {
153+ if expected_answers == 0 {
154+ break ;
155+ } else {
156+ continue ;
157+ }
158+ }
159+ }
160+ }
161+ Ok ( state)
162+ }
115163 UsbState :: Connected => {
116- // TODO: I'm not sure how we want to handle multiple usb devices
117- // just take the first one found for now.
118- // TODO: store this device reference, perhaps in the enum itself
119164 let handler = self . usb_uv_handler . clone ( ) ;
120165 let cred_request = self . cred_request . clone ( ) ;
121166 let signal_tx = self . usb_uv_handler . signal_tx . clone ( ) ;
122167 let pin_rx = self . usb_uv_handler . pin_rx . clone ( ) ;
168+ let mut device = self . chosen_hid_device . take ( ) . unwrap ( ) ;
169+ self . internal_hid_devices . clear ( ) ;
123170 tokio ( ) . spawn ( async move {
124- let mut devices = libwebauthn:: transport:: hid:: list_devices ( ) . await . unwrap ( ) ;
125- let device = devices. first_mut ( ) . unwrap ( ) ;
126171 let ( mut channel, state_rx) = device. channel ( ) . await . unwrap ( ) ;
127172 tokio ( ) . spawn ( async move {
128173 handle_usb_updates ( signal_tx, pin_rx, state_rx) . await ;
@@ -288,7 +333,7 @@ impl CredentialService {
288333 UsbState :: UserCancelled => Ok ( prev_usb_state) ,
289334 } ?;
290335
291- * self . usb_state . lock ( ) . await = next_usb_state;
336+ * self . usb_state . lock ( ) . await = next_usb_state. clone ( ) ;
292337 Ok ( next_usb_state)
293338 }
294339
@@ -402,7 +447,6 @@ impl CredentialService {
402447 }
403448}
404449
405-
406450#[ derive( Copy , Clone , Debug , Default , PartialEq ) ]
407451pub enum UsbState {
408452 /// Not polling for FIDO USB device.
@@ -433,6 +477,10 @@ pub enum UsbState {
433477
434478 // This isn't actually sent from the server.
435479 UserCancelled ,
480+
481+ // When we encounter multiple devices, we let all of them blink and continue
482+ // with the one that was tapped.
483+ SelectingDevice ,
436484}
437485
438486#[ derive( Clone , Debug , Default , PartialEq ) ]
0 commit comments