@@ -14,41 +14,36 @@ use libwebauthn::{
1414 self ,
1515 ops:: webauthn:: { GetAssertionResponse , MakeCredentialResponse } ,
1616} ;
17- use tokio:: sync:: {
18- mpsc:: { self , Receiver , Sender } ,
19- Mutex as AsyncMutex ,
20- } ;
17+ use tokio:: sync:: oneshot:: Sender ;
2118
2219use creds_lib:: {
23- client:: CredentialServiceClient ,
2420 model:: {
2521 CredentialRequest , CredentialResponse , Device , Error as CredentialServiceError , Operation ,
2622 Transport ,
2723 } ,
28- server:: { CreatePublicKeyCredentialRequest , ViewRequest } ,
24+ server:: ViewRequest ,
2925} ;
3026
3127use crate :: credential_service:: { hybrid:: HybridEvent , usb:: UsbEvent } ;
3228
3329use hybrid:: { HybridHandler , HybridState , HybridStateInternal } ;
3430use usb:: { UsbHandler , UsbStateInternal } ;
3531pub use {
36- server:: { CredentialManagementClient , InProcessServer , UiController } ,
32+ server:: { CredentialManagementClient , UiController } ,
3733 usb:: UsbState ,
3834} ;
3935
36+ type RequestContext = (
37+ CredentialRequest ,
38+ Sender < Result < CredentialResponse , CredentialServiceError > > ,
39+ ) ;
40+
4041#[ derive( Debug ) ]
4142pub struct CredentialService < H : HybridHandler , U : UsbHandler , UC : UiController > {
4243 devices : Vec < Device > ,
4344
44- cred_request : Mutex <
45- Option < (
46- CredentialRequest ,
47- Sender < Result < CredentialResponse , CredentialServiceError > > ,
48- ) > ,
49- > ,
50- // Place to store data to be returned to the caller
51- cred_response : Arc < Mutex < Option < CredentialResponse > > > ,
45+ /// Current request and channel to respond to caller.
46+ ctx : Arc < Mutex < Option < RequestContext > > > ,
5247
5348 hybrid_handler : H ,
5449 usb_handler : U ,
@@ -73,8 +68,7 @@ impl<H: HybridHandler + Debug, U: UsbHandler + Debug, UC: UiController + Debug>
7368 Self {
7469 devices,
7570
76- cred_request : Mutex :: new ( None ) ,
77- cred_response : Arc :: new ( Mutex :: new ( None ) ) ,
71+ ctx : Arc :: new ( Mutex :: new ( None ) ) ,
7872
7973 hybrid_handler,
8074 usb_handler,
@@ -88,25 +82,18 @@ impl<H: HybridHandler + Debug, U: UsbHandler + Debug, UC: UiController + Debug>
8882 request : & CredentialRequest ,
8983 tx : Sender < Result < CredentialResponse , CredentialServiceError > > ,
9084 ) {
91- // let (tx, rx) = mpsc::channel(1);
92- let res = {
93- let mut cred_request = self . cred_request . lock ( ) . unwrap ( ) ;
85+ {
86+ let mut cred_request = self . ctx . lock ( ) . unwrap ( ) ;
9487 if cred_request. is_some ( ) {
95- drop ( cred_request) ;
96- false
88+ tx. send ( Err ( CredentialServiceError :: Internal (
89+ "Already a request in progress." . to_string ( ) ,
90+ ) ) )
91+ . expect ( "Send to local receiver to succeed" ) ;
92+ return ;
9793 } else {
98- _ = cred_request. insert ( ( request. clone ( ) , tx. clone ( ) ) ) ;
99- true
94+ _ = cred_request. insert ( ( request. clone ( ) , tx) ) ;
10095 }
10196 } ;
102- if !res {
103- tx. send ( Err ( CredentialServiceError :: Internal (
104- "Already a request in progress." . to_string ( ) ,
105- ) ) )
106- . await
107- . expect ( "Send to local receiver to succeed" ) ;
108- return ;
109- }
11097 let operation = match & request {
11198 CredentialRequest :: CreatePublicKeyCredentialRequest ( _) => Operation :: Create ,
11299 CredentialRequest :: GetPublicKeyCredentialRequest ( _) => Operation :: Get ,
@@ -120,9 +107,10 @@ impl<H: HybridHandler + Debug, U: UsbHandler + Debug, UC: UiController + Debug>
120107 . map_err ( |err| err. to_string ( ) ) ;
121108 if let Err ( err) = launch_ui_response {
122109 tracing:: error!( "Failed to launch UI for credentials: {err}. Cancelling request." ) ;
123- _ = self . cred_request . lock ( ) . unwrap ( ) . take ( ) ;
110+ _ = self . ctx . lock ( ) . unwrap ( ) . take ( ) ;
124111 let err = Err ( CredentialServiceError :: Internal ( err) ) ;
125- tx. send ( err) . await ;
112+ let ( _, tx) = self . ctx . lock ( ) . unwrap ( ) . take ( ) . unwrap ( ) ;
113+ tx. send ( err) ;
126114 }
127115 }
128116
@@ -133,32 +121,39 @@ impl<H: HybridHandler + Debug, U: UsbHandler + Debug, UC: UiController + Debug>
133121 pub fn get_hybrid_credential (
134122 & self ,
135123 ) -> Pin < Box < dyn Stream < Item = HybridState > + Send + ' static > > {
136- let guard = self . cred_request . lock ( ) . unwrap ( ) ;
137- let cred_request = guard. clone ( ) . unwrap ( ) ;
138- let stream = self . hybrid_handler . start ( & cred_request. 0 ) ;
139- let cred_response = self . cred_response . clone ( ) ;
140- Box :: pin ( HybridStateStream {
141- inner : stream,
142- cred_response,
143- } )
124+ let guard = self . ctx . lock ( ) . unwrap ( ) ;
125+ if let Some ( ( ref cred_request, _) ) = * guard {
126+ let stream = self . hybrid_handler . start ( & cred_request) ;
127+ let ctx = self . ctx . clone ( ) ;
128+ Box :: pin ( HybridStateStream { inner : stream, ctx } )
129+ } else {
130+ tracing:: error!(
131+ "Attempted to start hybrid credential flow, but no request context was found."
132+ ) ;
133+ todo ! ( "Handle error when context is not set up." )
134+ }
144135 }
145136
146137 pub fn get_usb_credential ( & self ) -> Pin < Box < dyn Stream < Item = UsbState > + Send + ' static > > {
147- let guard = self . cred_request . lock ( ) . unwrap ( ) ;
148- let cred_request = guard. clone ( ) . unwrap ( ) ;
149- let stream = self . usb_handler . start ( & cred_request. 0 ) ;
150- Box :: pin ( UsbStateStream {
151- inner : stream,
152- cred_response : self . cred_response . clone ( ) ,
153- } )
138+ let guard = self . ctx . lock ( ) . unwrap ( ) ;
139+ if let Some ( ( ref cred_request, _) ) = * guard {
140+ let stream = self . usb_handler . start ( & cred_request) ;
141+ let ctx = self . ctx . clone ( ) ;
142+ Box :: pin ( UsbStateStream { inner : stream, ctx } )
143+ } else {
144+ tracing:: error!(
145+ "Attempted to start hybrid credential flow, but no request context was found."
146+ ) ;
147+ todo ! ( "Handle error when context is not set up." )
148+ }
154149 }
155150
156151 pub async fn complete_auth (
157152 & self ,
158153 response : Result < CredentialResponse , CredentialServiceError > ,
159154 ) -> ( ) {
160- if let Some ( ( _request, responder) ) = self . cred_request . lock ( ) . unwrap ( ) . take ( ) {
161- if responder. send ( response) . await . is_err ( ) {
155+ if let Some ( ( _request, responder) ) = self . ctx . lock ( ) . unwrap ( ) . take ( ) {
156+ if responder. send ( response) . is_err ( ) {
162157 tracing:: error!( "Failed to send response to back to caller" ) ;
163158 } ;
164159 } else {
@@ -169,7 +164,7 @@ impl<H: HybridHandler + Debug, U: UsbHandler + Debug, UC: UiController + Debug>
169164
170165pub struct HybridStateStream < H > {
171166 inner : H ,
172- cred_response : Arc < Mutex < Option < CredentialResponse > > > ,
167+ ctx : Arc < Mutex < Option < RequestContext > > > ,
173168}
174169
175170impl < H > Stream for HybridStateStream < H >
@@ -182,7 +177,7 @@ where
182177 self : Pin < & mut Self > ,
183178 cx : & mut std:: task:: Context < ' _ > ,
184179 ) -> Poll < Option < Self :: Item > > {
185- let cred_response = & self . cred_response . clone ( ) ;
180+ let ctx = & self . ctx . clone ( ) ;
186181 match Box :: pin ( Box :: pin ( self ) . as_mut ( ) . inner . next ( ) ) . poll ( cx) {
187182 Poll :: Pending => Poll :: Pending ,
188183 Poll :: Ready ( Some ( HybridEvent { state } ) ) => {
@@ -206,8 +201,7 @@ where
206201 )
207202 }
208203 } ;
209- let mut cred_response = cred_response. lock ( ) . unwrap ( ) ;
210- cred_response. replace ( response) ;
204+ complete_request ( ctx, response. clone ( ) ) ;
211205 }
212206 Poll :: Ready ( Some ( state. into ( ) ) )
213207 }
@@ -218,7 +212,7 @@ where
218212
219213struct UsbStateStream < H > {
220214 inner : H ,
221- cred_response : Arc < Mutex < Option < CredentialResponse > > > ,
215+ ctx : Arc < Mutex < Option < RequestContext > > > ,
222216}
223217
224218impl < H > Stream for UsbStateStream < H >
@@ -231,13 +225,12 @@ where
231225 self : Pin < & mut Self > ,
232226 cx : & mut std:: task:: Context < ' _ > ,
233227 ) -> Poll < Option < Self :: Item > > {
234- let cred_response = & self . cred_response . clone ( ) ;
228+ let ctx = & self . ctx . clone ( ) ;
235229 match Box :: pin ( Box :: pin ( self ) . as_mut ( ) . inner . next ( ) ) . poll ( cx) {
236230 Poll :: Pending => Poll :: Pending ,
237231 Poll :: Ready ( Some ( UsbEvent { state } ) ) => {
238232 if let UsbStateInternal :: Completed ( response) = & state {
239- let mut cred_response = cred_response. lock ( ) . unwrap ( ) ;
240- cred_response. replace ( response. clone ( ) ) ;
233+ complete_request ( ctx, response. clone ( ) ) ;
241234 }
242235 Poll :: Ready ( Some ( state. into ( ) ) )
243236 }
@@ -246,6 +239,18 @@ where
246239 }
247240}
248241
242+ fn complete_request ( ctx : & Mutex < Option < RequestContext > > , response : CredentialResponse ) {
243+ if let Some ( ( _, responder) ) = ctx. lock ( ) . unwrap ( ) . take ( ) {
244+ if responder. send ( Ok ( response) ) . is_err ( ) {
245+ tracing:: error!(
246+ "Attempted to send credential response to caller, but channel was closed."
247+ ) ;
248+ }
249+ } else {
250+ tracing:: error!( "Tried to consume context to respond to caller, but none was found." )
251+ }
252+ }
253+
249254#[ derive( Debug , Clone ) ]
250255enum AuthenticatorResponse {
251256 CredentialCreated ( MakeCredentialResponse ) ,
0 commit comments