11//! WebSocket tunnel-server transport for the caBLE hybrid protocol.
22use sha2:: { Digest , Sha256 } ;
33use tokio:: net:: TcpStream ;
4- use tokio_tungstenite:: tungstenite:: http:: StatusCode ;
4+ use tokio_tungstenite:: tungstenite:: handshake:: client:: Request ;
5+ use tokio_tungstenite:: tungstenite:: http:: { header:: LOCATION , StatusCode } ;
6+ use tokio_tungstenite:: tungstenite:: Error as TungsteniteError ;
57use tokio_tungstenite:: { connect_async, MaybeTlsStream , WebSocketStream } ;
68use tracing:: { debug, error, trace} ;
79use tungstenite:: client:: IntoClientRequest ;
10+ use url:: Url ;
811
12+ use super :: known_devices:: CableKnownDeviceId ;
913use super :: protocol:: CableTunnelConnectionType ;
1014use crate :: proto:: ctap2:: cbor;
1115use crate :: transport:: error:: TransportError ;
1216
17+ const MAX_TUNNEL_REDIRECTS : usize = 5 ;
18+
1319fn ensure_rustls_crypto_provider ( ) {
1420 use std:: sync:: Once ;
1521 static RUSTLS_INIT : Once = Once :: new ( ) ;
@@ -55,13 +61,77 @@ pub fn decode_tunnel_server_domain(encoded: u16) -> Option<String> {
5561 Some ( ret)
5662}
5763
64+ /// Builds the tunnel request, re-attaching the fido.cable and client-payload headers.
65+ pub ( crate ) fn build_tunnel_request (
66+ url : & str ,
67+ connection_type : & CableTunnelConnectionType ,
68+ ) -> Result < Request , TransportError > {
69+ let mut request = url
70+ . into_client_request ( )
71+ . or ( Err ( TransportError :: InvalidEndpoint ) ) ?;
72+ let headers = request. headers_mut ( ) ;
73+ headers. insert (
74+ "Sec-WebSocket-Protocol" ,
75+ "fido.cable"
76+ . parse ( )
77+ . or ( Err ( TransportError :: InvalidEndpoint ) ) ?,
78+ ) ;
79+
80+ if let CableTunnelConnectionType :: KnownDevice { client_payload, .. } = connection_type {
81+ let client_payload =
82+ cbor:: to_vec ( client_payload) . or ( Err ( TransportError :: InvalidEndpoint ) ) ?;
83+ headers. insert (
84+ "X-caBLE-Client-Payload" ,
85+ hex:: encode ( client_payload)
86+ . parse ( )
87+ . or ( Err ( TransportError :: InvalidEndpoint ) ) ?,
88+ ) ;
89+ }
90+ Ok ( request)
91+ }
92+
93+ /// Resolves a redirect Location, which may be relative, against the current URL.
94+ fn resolve_redirect_target ( base : & str , location : & str ) -> Result < String , TransportError > {
95+ let base = Url :: parse ( base) . or ( Err ( TransportError :: InvalidEndpoint ) ) ?;
96+ let target = base
97+ . join ( location)
98+ . or ( Err ( TransportError :: InvalidEndpoint ) ) ?;
99+ Ok ( target. to_string ( ) )
100+ }
101+
102+ /// Maps a non-101 tunnel handshake status to a transport error, distinguishing 410 Gone.
103+ fn tunnel_status_error ( status : StatusCode ) -> TransportError {
104+ if status == StatusCode :: GONE {
105+ TransportError :: TunnelServerGone
106+ } else {
107+ TransportError :: ConnectionFailed
108+ }
109+ }
110+
111+ /// The known-device id to forget on a 410 Gone, for a known-device connection.
112+ pub ( crate ) fn known_device_id_to_forget (
113+ error : & TransportError ,
114+ connection_type : & CableTunnelConnectionType ,
115+ ) -> Option < CableKnownDeviceId > {
116+ match ( error, connection_type) {
117+ (
118+ TransportError :: TunnelServerGone ,
119+ CableTunnelConnectionType :: KnownDevice {
120+ authenticator_public_key,
121+ ..
122+ } ,
123+ ) => Some ( hex:: encode ( authenticator_public_key) ) ,
124+ _ => None ,
125+ }
126+ }
127+
58128pub ( crate ) async fn connect (
59129 tunnel_domain : & str ,
60130 connection_type : & CableTunnelConnectionType ,
61131) -> Result < WebSocketStream < MaybeTlsStream < TcpStream > > , TransportError > {
62132 ensure_rustls_crypto_provider ( ) ;
63133
64- let connect_url = match connection_type {
134+ let mut connect_url = match connection_type {
65135 CableTunnelConnectionType :: QrCode {
66136 routing_id,
67137 tunnel_id,
@@ -74,50 +144,81 @@ pub(crate) async fn connect(
74144 format ! ( "wss://{}/cable/contact/{}" , tunnel_domain, contact_id)
75145 }
76146 } ;
77- debug ! ( ?connect_url, "Connecting to tunnel server" ) ;
78- let mut request = connect_url
79- . into_client_request ( )
80- . or ( Err ( TransportError :: InvalidEndpoint ) ) ?;
81- request. headers_mut ( ) . insert (
82- "Sec-WebSocket-Protocol" ,
83- "fido.cable"
84- . parse ( )
85- . or ( Err ( TransportError :: InvalidEndpoint ) ) ?,
86- ) ;
87147
88- if let CableTunnelConnectionType :: KnownDevice { client_payload, .. } = connection_type {
89- let client_payload =
90- cbor:: to_vec ( client_payload) . or ( Err ( TransportError :: InvalidEndpoint ) ) ?;
91- request. headers_mut ( ) . insert (
92- "X-caBLE-Client-Payload" ,
93- hex:: encode ( client_payload)
94- . parse ( )
95- . or ( Err ( TransportError :: InvalidEndpoint ) ) ?,
96- ) ;
97- }
98- trace ! ( ?request) ;
148+ for _ in 0 ..=MAX_TUNNEL_REDIRECTS {
149+ debug ! ( ?connect_url, "Connecting to tunnel server" ) ;
150+ let request = build_tunnel_request ( & connect_url, connection_type) ?;
151+ trace ! ( ?request) ;
152+
153+ let error = match connect_async ( request) . await {
154+ Ok ( ( ws_stream, response) ) => {
155+ debug ! ( ?response, "Connected to tunnel server" ) ;
156+ if response. status ( ) != StatusCode :: SWITCHING_PROTOCOLS {
157+ error ! ( ?response, "Failed to switch to websocket protocol" ) ;
158+ return Err ( TransportError :: ConnectionFailed ) ;
159+ }
160+ debug ! ( "Tunnel server returned success" ) ;
161+ return Ok ( ws_stream) ;
162+ }
163+ Err ( error) => error,
164+ } ;
99165
100- let ( ws_stream, response) = match connect_async ( request) . await {
101- Ok ( ( ws_stream, response) ) => ( ws_stream, response) ,
102- Err ( e) => {
103- error ! ( ?e, "Failed to connect to tunnel server" ) ;
166+ let TungsteniteError :: Http ( response) = error else {
167+ error ! ( ?error, "Failed to connect to tunnel server" ) ;
104168 return Err ( TransportError :: ConnectionFailed ) ;
169+ } ;
170+
171+ let status = response. status ( ) ;
172+ if status. is_redirection ( ) {
173+ let Some ( location) = response
174+ . headers ( )
175+ . get ( LOCATION )
176+ . and_then ( |value| value. to_str ( ) . ok ( ) )
177+ else {
178+ error ! ( ?status, "Tunnel redirect missing a usable Location header" ) ;
179+ return Err ( TransportError :: ConnectionFailed ) ;
180+ } ;
181+ connect_url = resolve_redirect_target ( & connect_url, location) ?;
182+ debug ! ( ?connect_url, "Following tunnel redirect" ) ;
183+ continue ;
105184 }
106- } ;
107- debug ! ( ?response, "Connected to tunnel server" ) ;
108185
109- if response. status ( ) != StatusCode :: SWITCHING_PROTOCOLS {
110- error ! ( ?response, "Failed to switch to websocket protocol" ) ;
111- return Err ( TransportError :: ConnectionFailed ) ;
186+ error ! ( ?status, "Tunnel server rejected the connection" ) ;
187+ return Err ( tunnel_status_error ( status) ) ;
112188 }
113- debug ! ( "Tunnel server returned success" ) ;
114189
115- Ok ( ws_stream)
190+ error ! ( "Exceeded the maximum number of tunnel redirects" ) ;
191+ Err ( TransportError :: ConnectionFailed )
116192}
117193
118194#[ cfg( test) ]
119195mod tests {
120196 use super :: * ;
197+ use crate :: transport:: cable:: known_devices:: { ClientPayload , ClientPayloadHint } ;
198+ use p256:: NonZeroScalar ;
199+ use rand:: rngs:: OsRng ;
200+ use serde_bytes:: ByteBuf ;
201+
202+ fn known_device_connection_type ( public_key : Vec < u8 > ) -> CableTunnelConnectionType {
203+ CableTunnelConnectionType :: KnownDevice {
204+ contact_id : "contact-id" . to_string ( ) ,
205+ authenticator_public_key : public_key,
206+ client_payload : ClientPayload {
207+ link_id : ByteBuf :: from ( vec ! [ 1u8 ; 8 ] ) ,
208+ client_nonce : ByteBuf :: from ( vec ! [ 2u8 ; 16 ] ) ,
209+ hint : ClientPayloadHint :: GetAssertion ,
210+ } ,
211+ }
212+ }
213+
214+ fn qr_connection_type ( ) -> CableTunnelConnectionType {
215+ CableTunnelConnectionType :: QrCode {
216+ routing_id : "aabbcc" . to_string ( ) ,
217+ tunnel_id : "00112233445566778899aabbccddeeff" . to_string ( ) ,
218+ private_key : NonZeroScalar :: random ( & mut OsRng ) ,
219+ }
220+ }
221+
121222 #[ test]
122223 fn decode_tunnel_server_domain_known ( ) {
123224 assert_eq ! (
@@ -130,5 +231,96 @@ mod tests {
130231 ) ;
131232 }
132233
133- // TODO: test the non-known case
234+ #[ test]
235+ fn resolve_redirect_target_relative_and_absolute ( ) {
236+ let base = "wss://cable.example.com/cable/contact/abc" ;
237+ assert_eq ! (
238+ resolve_redirect_target( base, "/cable/contact/v2/abc" ) . unwrap( ) ,
239+ "wss://cable.example.com/cable/contact/v2/abc"
240+ ) ;
241+ assert_eq ! (
242+ resolve_redirect_target( base, "wss://cable.example.net/cable/contact/xyz" ) . unwrap( ) ,
243+ "wss://cable.example.net/cable/contact/xyz"
244+ ) ;
245+ }
246+
247+ #[ test]
248+ fn build_tunnel_request_reattaches_headers_for_known_device ( ) {
249+ let connection_type = known_device_connection_type ( vec ! [ 4u8 ; 65 ] ) ;
250+ let request = build_tunnel_request (
251+ "wss://cable.example.com/cable/contact/abc" ,
252+ & connection_type,
253+ )
254+ . unwrap ( ) ;
255+ assert_eq ! (
256+ request
257+ . headers( )
258+ . get( "Sec-WebSocket-Protocol" )
259+ . unwrap( )
260+ . to_str( )
261+ . unwrap( ) ,
262+ "fido.cable"
263+ ) ;
264+ assert ! ( request. headers( ) . get( "X-caBLE-Client-Payload" ) . is_some( ) ) ;
265+ }
266+
267+ #[ test]
268+ fn build_tunnel_request_omits_payload_for_qr_code ( ) {
269+ let connection_type = qr_connection_type ( ) ;
270+ let request = build_tunnel_request (
271+ "wss://cable.example.com/cable/connect/aabbcc/0011" ,
272+ & connection_type,
273+ )
274+ . unwrap ( ) ;
275+ assert_eq ! (
276+ request
277+ . headers( )
278+ . get( "Sec-WebSocket-Protocol" )
279+ . unwrap( )
280+ . to_str( )
281+ . unwrap( ) ,
282+ "fido.cable"
283+ ) ;
284+ assert ! ( request. headers( ) . get( "X-caBLE-Client-Payload" ) . is_none( ) ) ;
285+ }
286+
287+ #[ test]
288+ fn gone_forgets_known_device ( ) {
289+ let public_key = vec ! [ 7u8 ; 65 ] ;
290+ let connection_type = known_device_connection_type ( public_key. clone ( ) ) ;
291+ assert_eq ! (
292+ known_device_id_to_forget( & TransportError :: TunnelServerGone , & connection_type) ,
293+ Some ( hex:: encode( & public_key) )
294+ ) ;
295+ }
296+
297+ #[ test]
298+ fn gone_does_not_forget_qr_code ( ) {
299+ let connection_type = qr_connection_type ( ) ;
300+ assert_eq ! (
301+ known_device_id_to_forget( & TransportError :: TunnelServerGone , & connection_type) ,
302+ None
303+ ) ;
304+ }
305+
306+ #[ test]
307+ fn non_gone_error_does_not_forget_known_device ( ) {
308+ let connection_type = known_device_connection_type ( vec ! [ 7u8 ; 65 ] ) ;
309+ assert_eq ! (
310+ known_device_id_to_forget( & TransportError :: ConnectionFailed , & connection_type) ,
311+ None
312+ ) ;
313+ }
314+
315+ #[ test]
316+ fn gone_status_maps_to_distinct_error ( ) {
317+ assert_eq ! (
318+ tunnel_status_error( StatusCode :: GONE ) ,
319+ TransportError :: TunnelServerGone
320+ ) ;
321+ assert_eq ! (
322+ tunnel_status_error( StatusCode :: BAD_GATEWAY ) ,
323+ TransportError :: ConnectionFailed
324+ ) ;
325+ }
134326}
0 commit comments