Skip to content

Commit b9e4596

Browse files
committed
fixup! fixup! cli: Close the websocket connection gracefully
1 parent c9ab2a9 commit b9e4596

1 file changed

Lines changed: 67 additions & 24 deletions

File tree

crates/cli/src/subcommands/subscribe.rs

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ use spacetimedb_data_structures::map::HashMap;
99
use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
1010
use spacetimedb_lib::de::serde::{DeserializeWrapper, SeedWrapper};
1111
use spacetimedb_lib::ser::serde::SerializeWrapper;
12+
use std::io;
1213
use std::time::Duration;
14+
use thiserror::Error;
1315
use tokio::io::AsyncWriteExt;
1416
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
1517
use tokio_tungstenite::tungstenite::{Error as WsError, Message as WsMessage};
@@ -181,18 +183,45 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error
181183

182184
// Close the connection gracefully, unless it's a websocket error,
183185
// in which case the connection is most likely already unusable.
184-
if !res.as_ref().is_err_and(|e| e.downcast_ref::<WsError>().is_some()) {
186+
if !matches!(res, Err(Error::Subscribe { .. } | Error::Websocket { .. })) {
185187
// Ignore errors here, we're going to drop the connection anyways.
186188
let _ = ws.close(None).await;
187189
}
188190

189-
res
191+
res.map_err(Into::into)
192+
}
193+
194+
#[derive(Debug, Error)]
195+
enum Error {
196+
#[error("error sending subscription queries")]
197+
Subscribe {
198+
#[source]
199+
source: WsError,
200+
},
201+
#[error("protocol error: {details}")]
202+
Protocol { details: &'static str },
203+
#[error("websocket error: {source}")]
204+
Websocket {
205+
#[source]
206+
source: WsError,
207+
},
208+
#[error("encountered failed transaction: {reason}")]
209+
TransactionFailure { reason: Box<str> },
210+
#[error("error formatting response: {source:#}")]
211+
Reformat {
212+
#[source]
213+
source: anyhow::Error,
214+
},
215+
#[error(transparent)]
216+
Serde(#[from] serde_json::Error),
217+
#[error(transparent)]
218+
Io(#[from] io::Error),
190219
}
191220

192221
/// Send the subscribe message.
193-
async fn subscribe<S>(ws: &mut S, query_strings: Box<[Box<str>]>) -> Result<(), S::Error>
222+
async fn subscribe<S>(ws: &mut S, query_strings: Box<[Box<str>]>) -> Result<(), Error>
194223
where
195-
S: Sink<WsMessage> + Unpin,
224+
S: Sink<WsMessage, Error = WsError> + Unpin,
196225
{
197226
let msg = serde_json::to_string(&SerializeWrapper::new(ws::ClientMessage::<()>::Subscribe(
198227
ws::Subscribe {
@@ -201,35 +230,39 @@ where
201230
},
202231
)))
203232
.unwrap();
204-
ws.send(msg.into()).await
233+
ws.send(msg.into()).await.map_err(|source| Error::Subscribe { source })
205234
}
206235

207236
/// Await the initial [`ServerMessage::SubscriptionUpdate`].
208237
/// If `module_def` is `Some`, print a JSON representation to stdout.
209-
async fn await_initial_update<S>(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> anyhow::Result<()>
238+
async fn await_initial_update<S>(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> Result<(), Error>
210239
where
211-
S: TryStream<Ok = WsMessage> + Unpin,
212-
S::Error: std::error::Error + Send + Sync + 'static,
240+
S: TryStream<Ok = WsMessage, Error = WsError> + Unpin,
213241
{
214242
const RECV_TX_UPDATE: &str = "protocol error: received transaction update before initial subscription update";
215243

216-
while let Some(msg) = ws.try_next().await? {
244+
while let Some(msg) = ws.try_next().await.map_err(|source| Error::Websocket { source })? {
217245
let Some(msg) = parse_msg_json(&msg) else { continue };
218246
match msg {
219247
ws::ServerMessage::InitialSubscription(sub) => {
220248
if let Some(module_def) = module_def {
221-
let formatted = reformat_update(&sub.database_update, module_def)?;
222-
let output = serde_json::to_string(&formatted)? + "\n";
249+
let output = format_output_json(&sub.database_update, module_def)?;
223250
tokio::io::stdout().write_all(output.as_bytes()).await?
224251
}
225252
break;
226253
}
227-
ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status, .. }) => anyhow::bail!(match status {
228-
ws::UpdateStatus::Failed(msg) => msg,
229-
_ => RECV_TX_UPDATE.into(),
230-
}),
254+
ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status, .. }) => {
255+
return Err(match status {
256+
ws::UpdateStatus::Failed(msg) => Error::TransactionFailure { reason: msg },
257+
_ => Error::Protocol {
258+
details: RECV_TX_UPDATE,
259+
},
260+
})
261+
}
231262
ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { .. }) => {
232-
anyhow::bail!(RECV_TX_UPDATE)
263+
return Err(Error::Protocol {
264+
details: RECV_TX_UPDATE,
265+
})
233266
}
234267
_ => continue,
235268
}
@@ -240,37 +273,47 @@ where
240273

241274
/// Print `num` [`ServerMessage::TransactionUpdate`] messages as JSON.
242275
/// If `num` is `None`, keep going indefinitely.
243-
async fn consume_transaction_updates<S>(ws: &mut S, num: Option<u32>, module_def: &RawModuleDefV9) -> anyhow::Result<()>
276+
async fn consume_transaction_updates<S>(ws: &mut S, num: Option<u32>, module_def: &RawModuleDefV9) -> Result<(), Error>
244277
where
245-
S: TryStream<Ok = WsMessage> + Unpin,
246-
S::Error: std::error::Error + Send + Sync + 'static,
278+
S: TryStream<Ok = WsMessage, Error = WsError> + Unpin,
247279
{
248280
let mut stdout = tokio::io::stdout();
249281
let mut num_received = 0;
250282
loop {
251283
if num.is_some_and(|n| num_received >= n) {
252-
break Ok(());
284+
return Ok(());
253285
}
254-
let Some(msg) = ws.try_next().await? else {
286+
let Some(msg) = ws.try_next().await.map_err(|source| Error::Websocket { source })? else {
255287
eprintln!("disconnected by server");
256-
break Err(WsError::ConnectionClosed.into());
288+
return Err(Error::Websocket {
289+
source: WsError::ConnectionClosed,
290+
});
257291
};
258292

259293
let Some(msg) = parse_msg_json(&msg) else { continue };
260294
match msg {
261295
ws::ServerMessage::InitialSubscription(_) => {
262-
anyhow::bail!("protocol error: received a second initial subscription update")
296+
return Err(Error::Protocol {
297+
details: "received a second initial subscription update",
298+
})
263299
}
264300
ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { update, .. })
265301
| ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate {
266302
status: ws::UpdateStatus::Committed(update),
267303
..
268304
}) => {
269-
let output = serde_json::to_string(&reformat_update(&update, module_def)?)? + "\n";
305+
let output = format_output_json(&update, module_def)?;
270306
stdout.write_all(output.as_bytes()).await?;
271307
num_received += 1;
272308
}
273309
_ => continue,
274310
}
275311
}
276312
}
313+
314+
fn format_output_json(msg: &ws::DatabaseUpdate<JsonFormat>, schema: &RawModuleDefV9) -> Result<String, Error> {
315+
let formatted = reformat_update(msg, schema).map_err(|source| Error::Reformat { source })?;
316+
let output = serde_json::to_string(&formatted)? + "\n";
317+
318+
Ok(output)
319+
}

0 commit comments

Comments
 (0)