Skip to content

Commit 46d94b4

Browse files
committed
Add token verification for the wasm sdk websocket connection
1 parent 33668b0 commit 46d94b4

4 files changed

Lines changed: 114 additions & 4 deletions

File tree

Cargo.lock

Lines changed: 23 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sdks/rust/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ default = []
1313
web = [
1414
"dep:getrandom",
1515
"dep:gloo-console",
16+
"dep:gloo-net",
1617
"dep:gloo-storage",
18+
"dep:js-sys",
1719
"dep:rustls-pki-types",
1820
"dep:tokio-tungstenite-wasm",
1921
"dep:wasm-bindgen",
@@ -44,7 +46,9 @@ rand.workspace = true
4446

4547
getrandom = { version = "0.3.2", features = ["wasm_js"], optional = true }
4648
gloo-console = { version = "0.3.0", optional = true }
49+
gloo-net = { version = "0.6.0", optional = true }
4750
gloo-storage = { version = "0.3.0", optional = true }
51+
js-sys = { version = "0.3", optional = true }
4852
rustls-pki-types = { version = "1.12.0", features = ["web"], optional = true }
4953
tokio-tungstenite-wasm = { version = "0.6.0", optional = true }
5054
wasm-bindgen = { version = "0.2.100", optional = true }

sdks/rust/src/db_connection.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1114,7 +1114,6 @@ but you must call one of them, or else the connection will never progress.
11141114
/// If the passed token is invalid or rejected by the host,
11151115
/// the connection will fail asynchrnonously.
11161116
// FIXME: currently this causes `disconnect` to be called rather than `on_connect_error`.
1117-
#[cfg(not(target_arch = "wasm32"))]
11181117
pub fn with_token(mut self, token: Option<impl Into<String>>) -> Self {
11191118
self.token = token.map(|token| token.into());
11201119
self

sdks/rust/src/websocket.rs

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ pub enum WsError {
9696

9797
#[error("Unrecognized compression scheme: {scheme:#x}")]
9898
UnknownCompressionScheme { scheme: u8 },
99+
100+
#[cfg(feature = "web")]
101+
#[error("Token verification error: {0}")]
102+
TokenVerification(String),
99103
}
100104

101105
pub(crate) struct WsConnection {
@@ -132,7 +136,29 @@ pub(crate) struct WsParams {
132136
pub confirmed: Option<bool>,
133137
}
134138

139+
#[cfg(not(feature = "web"))]
135140
fn make_uri(host: Uri, db_name: &str, connection_id: Option<ConnectionId>, params: WsParams) -> Result<Uri, UriError> {
141+
make_uri_impl(host, db_name, connection_id, params, None)
142+
}
143+
144+
#[cfg(feature = "web")]
145+
fn make_uri(
146+
host: Uri,
147+
db_name: &str,
148+
connection_id: Option<ConnectionId>,
149+
params: WsParams,
150+
token: Option<&str>,
151+
) -> Result<Uri, UriError> {
152+
make_uri_impl(host, db_name, connection_id, params, token)
153+
}
154+
155+
fn make_uri_impl(
156+
host: Uri,
157+
db_name: &str,
158+
connection_id: Option<ConnectionId>,
159+
params: WsParams,
160+
token: Option<&str>,
161+
) -> Result<Uri, UriError> {
136162
let mut parts = host.into_parts();
137163
let scheme = parse_scheme(parts.scheme.take())?;
138164
parts.scheme = Some(scheme);
@@ -181,6 +207,11 @@ fn make_uri(host: Uri, db_name: &str, connection_id: Option<ConnectionId>, param
181207
path.push_str(if confirmed { "true" } else { "false" });
182208
}
183209

210+
// Specify the `token` param if needed
211+
if let Some(token) = token {
212+
path.push_str(&format!("&token={token}"));
213+
}
214+
184215
parts.path_and_query = Some(path.parse().map_err(|source: InvalidUri| UriError::InvalidUri {
185216
source: Arc::new(source),
186217
})?);
@@ -232,10 +263,57 @@ fn request_insert_auth_header(req: &mut http::Request<()>, token: Option<&str>)
232263
}
233264
}
234265

266+
#[cfg(feature = "web")]
267+
async fn fetch_ws_token(host: &Uri, auth_token: &str) -> Result<String, WsError> {
268+
use gloo_net::http::{Method, RequestBuilder};
269+
use js_sys::{Reflect, JSON};
270+
use wasm_bindgen::{JsCast, JsValue};
271+
272+
let url = format!("{}v1/identity/websocket-token", host);
273+
274+
// helpers to convert gloo_net::Error or JsValue into WsError::TokenVerification
275+
let gloo_to_ws_err = |e: gloo_net::Error| match e {
276+
gloo_net::Error::JsError(js_err) => WsError::TokenVerification(js_err.message.into()),
277+
gloo_net::Error::SerdeError(e) => WsError::TokenVerification(e.to_string()),
278+
gloo_net::Error::GlooError(msg) => WsError::TokenVerification(msg),
279+
};
280+
let js_to_ws_err = |e: JsValue| {
281+
if let Some(err) = e.dyn_ref::<js_sys::Error>() {
282+
WsError::TokenVerification(err.message().into())
283+
} else if let Some(s) = e.as_string() {
284+
WsError::TokenVerification(s)
285+
} else {
286+
WsError::TokenVerification(format!("{:?}", e))
287+
}
288+
};
289+
290+
let res = RequestBuilder::new(&url)
291+
.method(Method::POST)
292+
.header("Authorization", &format!("Bearer {auth_token}"))
293+
.send()
294+
.await
295+
.map_err(gloo_to_ws_err)?;
296+
297+
if !res.ok() {
298+
return Err(WsError::TokenVerification(format!(
299+
"HTTP error: {} {}",
300+
res.status(),
301+
res.status_text()
302+
)));
303+
}
304+
305+
let body = res.text().await.map_err(gloo_to_ws_err)?;
306+
let json = JSON::parse(&body).map_err(js_to_ws_err)?;
307+
let token_js = Reflect::get(&json, &JsValue::from_str("token")).map_err(js_to_ws_err)?;
308+
token_js
309+
.as_string()
310+
.ok_or_else(|| WsError::TokenVerification("`token` parsing failed".into()))
311+
}
312+
235313
/// If `res` evaluates to `Err(e)`, log a warning in the form `"{}: {:?}", $cause, e`.
236314
///
237315
/// Could be trivially written as a function, but macro-ifying it preserves the source location of the log.
238-
#[cfg(not(target_arch = "wasm32"))]
316+
#[cfg(not(feature = "web"))]
239317
macro_rules! maybe_log_error {
240318
($cause:expr, $res:expr) => {
241319
if let Err(e) = $res {
@@ -281,11 +359,17 @@ impl WsConnection {
281359
pub(crate) async fn connect(
282360
host: Uri,
283361
db_name: &str,
284-
_token: Option<&str>,
362+
token: Option<&str>,
285363
connection_id: Option<ConnectionId>,
286364
params: WsParams,
287365
) -> Result<Self, WsError> {
288-
let uri = make_uri(host, db_name, connection_id, params)?;
366+
let token = if let Some(auth_token) = token {
367+
Some(fetch_ws_token(&host, auth_token).await?)
368+
} else {
369+
None
370+
};
371+
372+
let uri = make_uri(host, db_name, connection_id, params, token.as_deref())?;
289373
let sock = tokio_tungstenite_wasm::connect_with_protocols(&uri.to_string(), &[BIN_PROTOCOL])
290374
.await
291375
.map_err(|source| WsError::Tungstenite {

0 commit comments

Comments
 (0)