@@ -96,6 +96,10 @@ pub enum WsError {
9696
9797 #[ error( "Unrecognized compression scheme: {scheme:#x}" ) ]
9898 UnknownCompressionScheme { scheme : u8 } ,
99+
100+ #[ cfg( feature = "web" ) ]
101+ #[ error( "Token verification error: {0}" ) ]
102+ TokenVerification ( String ) ,
99103}
100104
101105pub ( crate ) struct WsConnection {
@@ -132,7 +136,29 @@ pub(crate) struct WsParams {
132136 pub confirmed : Option < bool > ,
133137}
134138
139+ #[ cfg( not( feature = "web" ) ) ]
135140fn make_uri ( host : Uri , db_name : & str , connection_id : Option < ConnectionId > , params : WsParams ) -> Result < Uri , UriError > {
141+ make_uri_impl ( host, db_name, connection_id, params, None )
142+ }
143+
144+ #[ cfg( feature = "web" ) ]
145+ fn make_uri (
146+ host : Uri ,
147+ db_name : & str ,
148+ connection_id : Option < ConnectionId > ,
149+ params : WsParams ,
150+ token : Option < & str > ,
151+ ) -> Result < Uri , UriError > {
152+ make_uri_impl ( host, db_name, connection_id, params, token)
153+ }
154+
155+ fn make_uri_impl (
156+ host : Uri ,
157+ db_name : & str ,
158+ connection_id : Option < ConnectionId > ,
159+ params : WsParams ,
160+ token : Option < & str > ,
161+ ) -> Result < Uri , UriError > {
136162 let mut parts = host. into_parts ( ) ;
137163 let scheme = parse_scheme ( parts. scheme . take ( ) ) ?;
138164 parts. scheme = Some ( scheme) ;
@@ -181,6 +207,11 @@ fn make_uri(host: Uri, db_name: &str, connection_id: Option<ConnectionId>, param
181207 path. push_str ( if confirmed { "true" } else { "false" } ) ;
182208 }
183209
210+ // Specify the `token` param if needed
211+ if let Some ( token) = token {
212+ path. push_str ( & format ! ( "&token={token}" ) ) ;
213+ }
214+
184215 parts. path_and_query = Some ( path. parse ( ) . map_err ( |source : InvalidUri | UriError :: InvalidUri {
185216 source : Arc :: new ( source) ,
186217 } ) ?) ;
@@ -232,10 +263,57 @@ fn request_insert_auth_header(req: &mut http::Request<()>, token: Option<&str>)
232263 }
233264}
234265
266+ #[ cfg( feature = "web" ) ]
267+ async fn fetch_ws_token ( host : & Uri , auth_token : & str ) -> Result < String , WsError > {
268+ use gloo_net:: http:: { Method , RequestBuilder } ;
269+ use js_sys:: { Reflect , JSON } ;
270+ use wasm_bindgen:: { JsCast , JsValue } ;
271+
272+ let url = format ! ( "{}v1/identity/websocket-token" , host) ;
273+
274+ // helpers to convert gloo_net::Error or JsValue into WsError::TokenVerification
275+ let gloo_to_ws_err = |e : gloo_net:: Error | match e {
276+ gloo_net:: Error :: JsError ( js_err) => WsError :: TokenVerification ( js_err. message . into ( ) ) ,
277+ gloo_net:: Error :: SerdeError ( e) => WsError :: TokenVerification ( e. to_string ( ) ) ,
278+ gloo_net:: Error :: GlooError ( msg) => WsError :: TokenVerification ( msg) ,
279+ } ;
280+ let js_to_ws_err = |e : JsValue | {
281+ if let Some ( err) = e. dyn_ref :: < js_sys:: Error > ( ) {
282+ WsError :: TokenVerification ( err. message ( ) . into ( ) )
283+ } else if let Some ( s) = e. as_string ( ) {
284+ WsError :: TokenVerification ( s)
285+ } else {
286+ WsError :: TokenVerification ( format ! ( "{:?}" , e) )
287+ }
288+ } ;
289+
290+ let res = RequestBuilder :: new ( & url)
291+ . method ( Method :: POST )
292+ . header ( "Authorization" , & format ! ( "Bearer {auth_token}" ) )
293+ . send ( )
294+ . await
295+ . map_err ( gloo_to_ws_err) ?;
296+
297+ if !res. ok ( ) {
298+ return Err ( WsError :: TokenVerification ( format ! (
299+ "HTTP error: {} {}" ,
300+ res. status( ) ,
301+ res. status_text( )
302+ ) ) ) ;
303+ }
304+
305+ let body = res. text ( ) . await . map_err ( gloo_to_ws_err) ?;
306+ let json = JSON :: parse ( & body) . map_err ( js_to_ws_err) ?;
307+ let token_js = Reflect :: get ( & json, & JsValue :: from_str ( "token" ) ) . map_err ( js_to_ws_err) ?;
308+ token_js
309+ . as_string ( )
310+ . ok_or_else ( || WsError :: TokenVerification ( "`token` parsing failed" . into ( ) ) )
311+ }
312+
235313/// If `res` evaluates to `Err(e)`, log a warning in the form `"{}: {:?}", $cause, e`.
236314///
237315/// Could be trivially written as a function, but macro-ifying it preserves the source location of the log.
238- #[ cfg( not( target_arch = "wasm32 " ) ) ]
316+ #[ cfg( not( feature = "web " ) ) ]
239317macro_rules! maybe_log_error {
240318 ( $cause: expr, $res: expr) => {
241319 if let Err ( e) = $res {
@@ -281,11 +359,17 @@ impl WsConnection {
281359 pub ( crate ) async fn connect (
282360 host : Uri ,
283361 db_name : & str ,
284- _token : Option < & str > ,
362+ token : Option < & str > ,
285363 connection_id : Option < ConnectionId > ,
286364 params : WsParams ,
287365 ) -> Result < Self , WsError > {
288- let uri = make_uri ( host, db_name, connection_id, params) ?;
366+ let token = if let Some ( auth_token) = token {
367+ Some ( fetch_ws_token ( & host, auth_token) . await ?)
368+ } else {
369+ None
370+ } ;
371+
372+ let uri = make_uri ( host, db_name, connection_id, params, token. as_deref ( ) ) ?;
289373 let sock = tokio_tungstenite_wasm:: connect_with_protocols ( & uri. to_string ( ) , & [ BIN_PROTOCOL ] )
290374 . await
291375 . map_err ( |source| WsError :: Tungstenite {
0 commit comments