Skip to content

Commit 1dbeba3

Browse files
committed
Add token verification for the wasm sdk websocket connection
1 parent 7bfb2ab commit 1dbeba3

4 files changed

Lines changed: 113 additions & 3 deletions

File tree

Cargo.lock

Lines changed: 23 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/sdk/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ default = []
1212
web = [
1313
"dep:getrandom",
1414
"dep:gloo-console",
15+
"dep:gloo-net",
1516
"dep:gloo-storage",
17+
"dep:js-sys",
1618
"dep:rustls-pki-types",
1719
"dep:tokio-tungstenite-wasm",
1820
"dep:wasm-bindgen",
@@ -42,7 +44,9 @@ rand.workspace = true
4244

4345
getrandom = { version = "0.3.2", features = ["wasm_js"], optional = true }
4446
gloo-console = { version = "0.3.0", optional = true }
47+
gloo-net = { version = "0.6.0", optional = true }
4548
gloo-storage = { version = "0.3.0", optional = true }
49+
js-sys = { version = "0.3", optional = true }
4650
rustls-pki-types = { version = "1.12.0", features = ["web"], optional = true }
4751
tokio-tungstenite-wasm = { version = "0.6.0", optional = true }
4852
wasm-bindgen = { version = "0.2.100", optional = true }

crates/sdk/src/db_connection.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1013,7 +1013,6 @@ but you must call one of them, or else the connection will never progress.
10131013
/// If the passed token is invalid or rejected by the host,
10141014
/// the connection will fail asynchrnonously.
10151015
// FIXME: currently this causes `disconnect` to be called rather than `on_connect_error`.
1016-
#[cfg(not(target_arch = "wasm32"))]
10171016
pub fn with_token(mut self, token: Option<impl Into<String>>) -> Self {
10181017
self.token = token.map(|token| token.into());
10191018
self

crates/sdk/src/websocket.rs

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ pub enum WsError {
9898

9999
#[error("Unrecognized compression scheme: {scheme:#x}")]
100100
UnknownCompressionScheme { scheme: u8 },
101+
102+
#[cfg(feature = "web")]
103+
#[error("Token verification error: {0}")]
104+
TokenVerification(String),
101105
}
102106

103107
pub(crate) struct WsConnection {
@@ -131,7 +135,29 @@ pub(crate) struct WsParams {
131135
pub light: bool,
132136
}
133137

138+
#[cfg(not(feature = "web"))]
134139
fn make_uri(host: Uri, db_name: &str, connection_id: ConnectionId, params: WsParams) -> Result<Uri, UriError> {
140+
make_uri_impl(host, db_name, connection_id, params, None)
141+
}
142+
143+
#[cfg(feature = "web")]
144+
fn make_uri(
145+
host: Uri,
146+
db_name: &str,
147+
connection_id: ConnectionId,
148+
params: WsParams,
149+
token: Option<&str>,
150+
) -> Result<Uri, UriError> {
151+
make_uri_impl(host, db_name, connection_id, params, token)
152+
}
153+
154+
fn make_uri_impl(
155+
host: Uri,
156+
db_name: &str,
157+
connection_id: ConnectionId,
158+
params: WsParams,
159+
token: Option<&str>,
160+
) -> Result<Uri, UriError> {
135161
let mut parts = host.into_parts();
136162
let scheme = parse_scheme(parts.scheme.take())?;
137163
parts.scheme = Some(scheme);
@@ -171,6 +197,11 @@ fn make_uri(host: Uri, db_name: &str, connection_id: ConnectionId, params: WsPar
171197
path.push_str("&light=true");
172198
}
173199

200+
// Specify the `token` param if needed
201+
if let Some(token) = token {
202+
path.push_str(&format!("&token={token}"));
203+
}
204+
174205
parts.path_and_query = Some(path.parse().map_err(|source: InvalidUri| UriError::InvalidUri {
175206
source: Arc::new(source),
176207
})?);
@@ -222,6 +253,53 @@ fn request_insert_auth_header(req: &mut http::Request<()>, token: Option<&str>)
222253
}
223254
}
224255

256+
#[cfg(feature = "web")]
257+
async fn fetch_ws_token(host: &Uri, auth_token: &str) -> Result<String, WsError> {
258+
use gloo_net::http::{Method, RequestBuilder};
259+
use js_sys::{Reflect, JSON};
260+
use wasm_bindgen::{JsCast, JsValue};
261+
262+
let url = format!("{}v1/identity/websocket-token", host);
263+
264+
// helpers to convert gloo_net::Error or JsValue into WsError::TokenVerification
265+
let gloo_to_ws_err = |e: gloo_net::Error| match e {
266+
gloo_net::Error::JsError(js_err) => WsError::TokenVerification(js_err.message.into()),
267+
gloo_net::Error::SerdeError(e) => WsError::TokenVerification(e.to_string()),
268+
gloo_net::Error::GlooError(msg) => WsError::TokenVerification(msg),
269+
};
270+
let js_to_ws_err = |e: JsValue| {
271+
if let Some(err) = e.dyn_ref::<js_sys::Error>() {
272+
WsError::TokenVerification(err.message().into())
273+
} else if let Some(s) = e.as_string() {
274+
WsError::TokenVerification(s)
275+
} else {
276+
WsError::TokenVerification(format!("{:?}", e))
277+
}
278+
};
279+
280+
let res = RequestBuilder::new(&url)
281+
.method(Method::POST)
282+
.header("Authorization", &format!("Bearer {auth_token}"))
283+
.send()
284+
.await
285+
.map_err(gloo_to_ws_err)?;
286+
287+
if !res.ok() {
288+
return Err(WsError::TokenVerification(format!(
289+
"HTTP error: {} {}",
290+
res.status(),
291+
res.status_text()
292+
)));
293+
}
294+
295+
let body = res.text().await.map_err(gloo_to_ws_err)?;
296+
let json = JSON::parse(&body).map_err(js_to_ws_err)?;
297+
let token_js = Reflect::get(&json, &JsValue::from_str("token")).map_err(js_to_ws_err)?;
298+
token_js
299+
.as_string()
300+
.ok_or_else(|| WsError::TokenVerification("`token` parsing failed".into()))
301+
}
302+
225303
impl WsConnection {
226304
#[cfg(not(feature = "web"))]
227305
pub(crate) async fn connect(
@@ -260,11 +338,17 @@ impl WsConnection {
260338
pub(crate) async fn connect(
261339
host: Uri,
262340
db_name: &str,
263-
_token: Option<&str>,
341+
token: Option<&str>,
264342
connection_id: ConnectionId,
265343
params: WsParams,
266344
) -> Result<Self, WsError> {
267-
let uri = make_uri(host, db_name, connection_id, params)?;
345+
let token = if let Some(auth_token) = token {
346+
Some(fetch_ws_token(&host, auth_token).await?)
347+
} else {
348+
None
349+
};
350+
351+
let uri = make_uri(host, db_name, connection_id, params, token.as_deref())?;
268352
let sock = tokio_tungstenite_wasm::connect_with_protocols(&uri.to_string(), &[BIN_PROTOCOL])
269353
.await
270354
.map_err(|source| WsError::Tungstenite {

0 commit comments

Comments
 (0)