11//! WebSocket tunnel-server transport for the caBLE hybrid protocol.
22use sha2:: { Digest , Sha256 } ;
33use tokio:: net:: TcpStream ;
4- use tokio_tungstenite:: tungstenite:: http:: StatusCode ;
4+ use tokio_tungstenite:: tungstenite:: handshake:: client:: Request ;
5+ use tokio_tungstenite:: tungstenite:: http:: { header:: LOCATION , StatusCode } ;
6+ use tokio_tungstenite:: tungstenite:: Error as TungsteniteError ;
57use tokio_tungstenite:: { connect_async, MaybeTlsStream , WebSocketStream } ;
68use tracing:: { debug, error, trace} ;
79use tungstenite:: client:: IntoClientRequest ;
10+ use url:: Url ;
811
12+ use super :: error:: CableTunnelError ;
13+ use super :: known_devices:: CableKnownDeviceId ;
914use super :: protocol:: CableTunnelConnectionType ;
1015use crate :: proto:: ctap2:: cbor;
1116use crate :: transport:: error:: TransportError ;
1217
18+ const MAX_TUNNEL_REDIRECTS : usize = 5 ;
19+
1320fn ensure_rustls_crypto_provider ( ) {
1421 use std:: sync:: Once ;
1522 static RUSTLS_INIT : Once = Once :: new ( ) ;
@@ -55,13 +62,77 @@ pub fn decode_tunnel_server_domain(encoded: u16) -> Option<String> {
5562 Some ( ret)
5663}
5764
65+ /// Builds the tunnel request, re-attaching the fido.cable and client-payload headers.
66+ pub ( crate ) fn build_tunnel_request (
67+ url : & str ,
68+ connection_type : & CableTunnelConnectionType ,
69+ ) -> Result < Request , TransportError > {
70+ let mut request = url
71+ . into_client_request ( )
72+ . or ( Err ( TransportError :: InvalidEndpoint ) ) ?;
73+ let headers = request. headers_mut ( ) ;
74+ headers. insert (
75+ "Sec-WebSocket-Protocol" ,
76+ "fido.cable"
77+ . parse ( )
78+ . or ( Err ( TransportError :: InvalidEndpoint ) ) ?,
79+ ) ;
80+
81+ if let CableTunnelConnectionType :: KnownDevice { client_payload, .. } = connection_type {
82+ let client_payload =
83+ cbor:: to_vec ( client_payload) . or ( Err ( TransportError :: InvalidEndpoint ) ) ?;
84+ headers. insert (
85+ "X-caBLE-Client-Payload" ,
86+ hex:: encode ( client_payload)
87+ . parse ( )
88+ . or ( Err ( TransportError :: InvalidEndpoint ) ) ?,
89+ ) ;
90+ }
91+ Ok ( request)
92+ }
93+
94+ /// Resolves a redirect Location, which may be relative, against the current URL.
95+ fn resolve_redirect_target ( base : & str , location : & str ) -> Result < String , TransportError > {
96+ let base = Url :: parse ( base) . or ( Err ( TransportError :: InvalidEndpoint ) ) ?;
97+ let target = base
98+ . join ( location)
99+ . or ( Err ( TransportError :: InvalidEndpoint ) ) ?;
100+ Ok ( target. to_string ( ) )
101+ }
102+
103+ /// Maps a non-101 tunnel handshake status to a transport error, distinguishing 410 Gone.
104+ fn tunnel_status_error ( status : StatusCode ) -> TransportError {
105+ if status == StatusCode :: GONE {
106+ CableTunnelError :: Gone . into ( )
107+ } else {
108+ CableTunnelError :: UnexpectedStatus ( status. as_u16 ( ) ) . into ( )
109+ }
110+ }
111+
112+ /// The known-device id to forget on a 410 Gone, for a known-device connection.
113+ pub ( crate ) fn known_device_id_to_forget (
114+ error : & TransportError ,
115+ connection_type : & CableTunnelConnectionType ,
116+ ) -> Option < CableKnownDeviceId > {
117+ match ( error, connection_type) {
118+ (
119+ TransportError :: CableTunnel ( CableTunnelError :: Gone ) ,
120+ CableTunnelConnectionType :: KnownDevice {
121+ authenticator_public_key,
122+ ..
123+ } ,
124+ ) => Some ( hex:: encode ( authenticator_public_key) ) ,
125+ _ => None ,
126+ }
127+ }
128+
58129pub ( crate ) async fn connect (
59130 tunnel_domain : & str ,
60131 connection_type : & CableTunnelConnectionType ,
61132) -> Result < WebSocketStream < MaybeTlsStream < TcpStream > > , TransportError > {
62133 ensure_rustls_crypto_provider ( ) ;
63134
64- let connect_url = match connection_type {
135+ let mut connect_url = match connection_type {
65136 CableTunnelConnectionType :: QrCode {
66137 routing_id,
67138 tunnel_id,
@@ -74,50 +145,81 @@ pub(crate) async fn connect(
74145 format ! ( "wss://{}/cable/contact/{}" , tunnel_domain, contact_id)
75146 }
76147 } ;
77- debug ! ( ?connect_url, "Connecting to tunnel server" ) ;
78- let mut request = connect_url
79- . into_client_request ( )
80- . or ( Err ( TransportError :: InvalidEndpoint ) ) ?;
81- request. headers_mut ( ) . insert (
82- "Sec-WebSocket-Protocol" ,
83- "fido.cable"
84- . parse ( )
85- . or ( Err ( TransportError :: InvalidEndpoint ) ) ?,
86- ) ;
87148
88- if let CableTunnelConnectionType :: KnownDevice { client_payload, .. } = connection_type {
89- let client_payload =
90- cbor:: to_vec ( client_payload) . or ( Err ( TransportError :: InvalidEndpoint ) ) ?;
91- request. headers_mut ( ) . insert (
92- "X-caBLE-Client-Payload" ,
93- hex:: encode ( client_payload)
94- . parse ( )
95- . or ( Err ( TransportError :: InvalidEndpoint ) ) ?,
96- ) ;
97- }
98- trace ! ( ?request) ;
149+ for _ in 0 ..=MAX_TUNNEL_REDIRECTS {
150+ debug ! ( ?connect_url, "Connecting to tunnel server" ) ;
151+ let request = build_tunnel_request ( & connect_url, connection_type) ?;
152+ trace ! ( ?request) ;
153+
154+ let error = match connect_async ( request) . await {
155+ Ok ( ( ws_stream, response) ) => {
156+ debug ! ( ?response, "Connected to tunnel server" ) ;
157+ if response. status ( ) != StatusCode :: SWITCHING_PROTOCOLS {
158+ error ! ( ?response, "Failed to switch to websocket protocol" ) ;
159+ return Err ( TransportError :: ConnectionFailed ) ;
160+ }
161+ debug ! ( "Tunnel server returned success" ) ;
162+ return Ok ( ws_stream) ;
163+ }
164+ Err ( error) => error,
165+ } ;
99166
100- let ( ws_stream, response) = match connect_async ( request) . await {
101- Ok ( ( ws_stream, response) ) => ( ws_stream, response) ,
102- Err ( e) => {
103- error ! ( ?e, "Failed to connect to tunnel server" ) ;
167+ let TungsteniteError :: Http ( response) = error else {
168+ error ! ( ?error, "Failed to connect to tunnel server" ) ;
104169 return Err ( TransportError :: ConnectionFailed ) ;
170+ } ;
171+
172+ let status = response. status ( ) ;
173+ if status. is_redirection ( ) {
174+ let Some ( location) = response
175+ . headers ( )
176+ . get ( LOCATION )
177+ . and_then ( |value| value. to_str ( ) . ok ( ) )
178+ else {
179+ error ! ( ?status, "Tunnel redirect missing a usable Location header" ) ;
180+ return Err ( TransportError :: ConnectionFailed ) ;
181+ } ;
182+ connect_url = resolve_redirect_target ( & connect_url, location) ?;
183+ debug ! ( ?connect_url, "Following tunnel redirect" ) ;
184+ continue ;
105185 }
106- } ;
107- debug ! ( ?response, "Connected to tunnel server" ) ;
108186
109- if response. status ( ) != StatusCode :: SWITCHING_PROTOCOLS {
110- error ! ( ?response, "Failed to switch to websocket protocol" ) ;
111- return Err ( TransportError :: ConnectionFailed ) ;
187+ error ! ( ?status, "Tunnel server rejected the connection" ) ;
188+ return Err ( tunnel_status_error ( status) ) ;
112189 }
113- debug ! ( "Tunnel server returned success" ) ;
114190
115- Ok ( ws_stream)
191+ error ! ( "Exceeded the maximum number of tunnel redirects" ) ;
192+ Err ( CableTunnelError :: TooManyRedirects . into ( ) )
116193}
117194
118195#[ cfg( test) ]
119196mod tests {
120197 use super :: * ;
198+ use crate :: transport:: cable:: known_devices:: { ClientPayload , ClientPayloadHint } ;
199+ use p256:: NonZeroScalar ;
200+ use rand:: rngs:: OsRng ;
201+ use serde_bytes:: ByteBuf ;
202+
203+ fn known_device_connection_type ( public_key : Vec < u8 > ) -> CableTunnelConnectionType {
204+ CableTunnelConnectionType :: KnownDevice {
205+ contact_id : "contact-id" . to_string ( ) ,
206+ authenticator_public_key : public_key,
207+ client_payload : ClientPayload {
208+ link_id : ByteBuf :: from ( vec ! [ 1u8 ; 8 ] ) ,
209+ client_nonce : ByteBuf :: from ( vec ! [ 2u8 ; 16 ] ) ,
210+ hint : ClientPayloadHint :: GetAssertion ,
211+ } ,
212+ }
213+ }
214+
215+ fn qr_connection_type ( ) -> CableTunnelConnectionType {
216+ CableTunnelConnectionType :: QrCode {
217+ routing_id : "aabbcc" . to_string ( ) ,
218+ tunnel_id : "00112233445566778899aabbccddeeff" . to_string ( ) ,
219+ private_key : NonZeroScalar :: random ( & mut OsRng ) ,
220+ }
221+ }
222+
121223 #[ test]
122224 fn decode_tunnel_server_domain_known ( ) {
123225 assert_eq ! (
@@ -130,5 +232,102 @@ mod tests {
130232 ) ;
131233 }
132234
133- // TODO: test the non-known case
235+ #[ test]
236+ fn resolve_redirect_target_relative_and_absolute ( ) {
237+ let base = "wss://cable.example.com/cable/contact/abc" ;
238+ assert_eq ! (
239+ resolve_redirect_target( base, "/cable/contact/v2/abc" ) . unwrap( ) ,
240+ "wss://cable.example.com/cable/contact/v2/abc"
241+ ) ;
242+ assert_eq ! (
243+ resolve_redirect_target( base, "wss://cable.example.net/cable/contact/xyz" ) . unwrap( ) ,
244+ "wss://cable.example.net/cable/contact/xyz"
245+ ) ;
246+ }
247+
248+ #[ test]
249+ fn build_tunnel_request_reattaches_headers_for_known_device ( ) {
250+ let connection_type = known_device_connection_type ( vec ! [ 4u8 ; 65 ] ) ;
251+ let request = build_tunnel_request (
252+ "wss://cable.example.com/cable/contact/abc" ,
253+ & connection_type,
254+ )
255+ . unwrap ( ) ;
256+ assert_eq ! (
257+ request
258+ . headers( )
259+ . get( "Sec-WebSocket-Protocol" )
260+ . unwrap( )
261+ . to_str( )
262+ . unwrap( ) ,
263+ "fido.cable"
264+ ) ;
265+ assert ! ( request. headers( ) . get( "X-caBLE-Client-Payload" ) . is_some( ) ) ;
266+ }
267+
268+ #[ test]
269+ fn build_tunnel_request_omits_payload_for_qr_code ( ) {
270+ let connection_type = qr_connection_type ( ) ;
271+ let request = build_tunnel_request (
272+ "wss://cable.example.com/cable/connect/aabbcc/0011" ,
273+ & connection_type,
274+ )
275+ . unwrap ( ) ;
276+ assert_eq ! (
277+ request
278+ . headers( )
279+ . get( "Sec-WebSocket-Protocol" )
280+ . unwrap( )
281+ . to_str( )
282+ . unwrap( ) ,
283+ "fido.cable"
284+ ) ;
285+ assert ! ( request. headers( ) . get( "X-caBLE-Client-Payload" ) . is_none( ) ) ;
286+ }
287+
288+ #[ test]
289+ fn gone_forgets_known_device ( ) {
290+ let public_key = vec ! [ 7u8 ; 65 ] ;
291+ let connection_type = known_device_connection_type ( public_key. clone ( ) ) ;
292+ assert_eq ! (
293+ known_device_id_to_forget(
294+ & TransportError :: CableTunnel ( CableTunnelError :: Gone ) ,
295+ & connection_type
296+ ) ,
297+ Some ( hex:: encode( & public_key) )
298+ ) ;
299+ }
300+
301+ #[ test]
302+ fn gone_does_not_forget_qr_code ( ) {
303+ let connection_type = qr_connection_type ( ) ;
304+ assert_eq ! (
305+ known_device_id_to_forget(
306+ & TransportError :: CableTunnel ( CableTunnelError :: Gone ) ,
307+ & connection_type
308+ ) ,
309+ None
310+ ) ;
311+ }
312+
313+ #[ test]
314+ fn non_gone_error_does_not_forget_known_device ( ) {
315+ let connection_type = known_device_connection_type ( vec ! [ 7u8 ; 65 ] ) ;
316+ assert_eq ! (
317+ known_device_id_to_forget( & TransportError :: ConnectionFailed , & connection_type) ,
318+ None
319+ ) ;
320+ }
321+
322+ #[ test]
323+ fn gone_status_maps_to_distinct_error ( ) {
324+ assert_eq ! (
325+ tunnel_status_error( StatusCode :: GONE ) ,
326+ TransportError :: CableTunnel ( CableTunnelError :: Gone )
327+ ) ;
328+ assert_eq ! (
329+ tunnel_status_error( StatusCode :: BAD_GATEWAY ) ,
330+ TransportError :: CableTunnel ( CableTunnelError :: UnexpectedStatus ( 502 ) )
331+ ) ;
332+ }
134333}
0 commit comments