Skip to content

Commit bdac6ba

Browse files
authored
Merge branch 'master' into jdetter/login-and-resume-gameplay
2 parents c4675cd + 2227e0f commit bdac6ba

29 files changed

Lines changed: 357 additions & 109 deletions

File tree

crates/cli/src/subcommands/subscribe.rs

Lines changed: 101 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ use spacetimedb_data_structures::map::HashMap;
99
use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
1010
use spacetimedb_lib::de::serde::{DeserializeWrapper, SeedWrapper};
1111
use spacetimedb_lib::ser::serde::SerializeWrapper;
12+
use std::io;
1213
use std::time::Duration;
14+
use thiserror::Error;
1315
use tokio::io::AsyncWriteExt;
1416
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
15-
use tokio_tungstenite::tungstenite::Message as WsMessage;
17+
use tokio_tungstenite::tungstenite::{Error as WsError, Message as WsMessage};
1618

1719
use crate::api::ClientApi;
1820
use crate::common_args;
@@ -155,35 +157,88 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error
155157
if let Some(auth_header) = api.con.auth_header.to_header() {
156158
req.headers_mut().insert(header::AUTHORIZATION, auth_header);
157159
}
158-
let (mut ws, _) = tokio_tungstenite::connect_async(req).await?;
160+
let mut ws = tokio_tungstenite::connect_async(req).await.map(|(ws, _)| ws)?;
159161

160162
let task = async {
161163
subscribe(&mut ws, queries.cloned().map(Into::into).collect()).await?;
162164
await_initial_update(&mut ws, print_initial_update.then_some(&module_def)).await?;
163165
consume_transaction_updates(&mut ws, num, &module_def).await
164166
};
165167

166-
let needs_shutdown = if let Some(timeout) = timeout {
168+
let res = if let Some(timeout) = timeout {
167169
let timeout = Duration::from_secs(timeout.into());
168170
match tokio::time::timeout(timeout, task).await {
169-
Ok(res) => res?,
170-
Err(_elapsed) => true,
171+
Ok(res) => res,
172+
Err(_elapsed) => {
173+
eprintln!("timed out after {}s", timeout.as_secs());
174+
Ok(())
175+
}
171176
}
172177
} else {
173-
task.await?
178+
task.await
174179
};
175180

176-
if needs_shutdown {
177-
ws.close(None).await?;
178-
}
181+
// Close the connection gracefully.
182+
// This will return an error if the server already closed,
183+
// or the connection is in a bad state.
184+
// The error (if any) relevant to the user is already stored in `res`,
185+
// so we can ignore errors here -- graceful close is basically a
186+
// courtesy to the server.
187+
let _ = ws.close(None).await;
188+
// The server closing the connection is not considered an error,
189+
// but any other error is.
190+
res.or_else(|e| {
191+
if e.is_server_closed_connection() {
192+
Ok(())
193+
} else {
194+
Err(e)
195+
}
196+
})
197+
.map_err(anyhow::Error::from)
198+
}
179199

180-
Ok(())
200+
#[derive(Debug, Error)]
201+
enum Error {
202+
#[error("error sending subscription queries")]
203+
Subscribe {
204+
#[source]
205+
source: WsError,
206+
},
207+
#[error("protocol error: {details}")]
208+
Protocol { details: &'static str },
209+
#[error("websocket error: {source}")]
210+
Websocket {
211+
#[source]
212+
source: WsError,
213+
},
214+
#[error("encountered failed transaction: {reason}")]
215+
TransactionFailure { reason: Box<str> },
216+
#[error("error formatting response: {source:#}")]
217+
Reformat {
218+
#[source]
219+
source: anyhow::Error,
220+
},
221+
#[error(transparent)]
222+
Serde(#[from] serde_json::Error),
223+
#[error(transparent)]
224+
Io(#[from] io::Error),
225+
}
226+
227+
impl Error {
228+
fn is_server_closed_connection(&self) -> bool {
229+
matches!(
230+
self,
231+
Self::Websocket {
232+
source: WsError::ConnectionClosed
233+
}
234+
)
235+
}
181236
}
182237

183238
/// Send the subscribe message.
184-
async fn subscribe<S>(ws: &mut S, query_strings: Box<[Box<str>]>) -> Result<(), S::Error>
239+
async fn subscribe<S>(ws: &mut S, query_strings: Box<[Box<str>]>) -> Result<(), Error>
185240
where
186-
S: Sink<WsMessage> + Unpin,
241+
S: Sink<WsMessage, Error = WsError> + Unpin,
187242
{
188243
let msg = serde_json::to_string(&SerializeWrapper::new(ws::ClientMessage::<()>::Subscribe(
189244
ws::Subscribe {
@@ -192,35 +247,39 @@ where
192247
},
193248
)))
194249
.unwrap();
195-
ws.send(msg.into()).await
250+
ws.send(msg.into()).await.map_err(|source| Error::Subscribe { source })
196251
}
197252

198253
/// Await the initial [`ServerMessage::SubscriptionUpdate`].
199254
/// If `module_def` is `Some`, print a JSON representation to stdout.
200-
async fn await_initial_update<S>(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> anyhow::Result<()>
255+
async fn await_initial_update<S>(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> Result<(), Error>
201256
where
202-
S: TryStream<Ok = WsMessage> + Unpin,
203-
S::Error: std::error::Error + Send + Sync + 'static,
257+
S: TryStream<Ok = WsMessage, Error = WsError> + Unpin,
204258
{
205259
const RECV_TX_UPDATE: &str = "protocol error: received transaction update before initial subscription update";
206260

207-
while let Some(msg) = ws.try_next().await? {
261+
while let Some(msg) = ws.try_next().await.map_err(|source| Error::Websocket { source })? {
208262
let Some(msg) = parse_msg_json(&msg) else { continue };
209263
match msg {
210264
ws::ServerMessage::InitialSubscription(sub) => {
211265
if let Some(module_def) = module_def {
212-
let formatted = reformat_update(&sub.database_update, module_def)?;
213-
let output = serde_json::to_string(&formatted)? + "\n";
266+
let output = format_output_json(&sub.database_update, module_def)?;
214267
tokio::io::stdout().write_all(output.as_bytes()).await?
215268
}
216269
break;
217270
}
218-
ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status, .. }) => anyhow::bail!(match status {
219-
ws::UpdateStatus::Failed(msg) => msg,
220-
_ => RECV_TX_UPDATE.into(),
221-
}),
271+
ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status, .. }) => {
272+
return Err(match status {
273+
ws::UpdateStatus::Failed(msg) => Error::TransactionFailure { reason: msg },
274+
_ => Error::Protocol {
275+
details: RECV_TX_UPDATE,
276+
},
277+
})
278+
}
222279
ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { .. }) => {
223-
anyhow::bail!(RECV_TX_UPDATE)
280+
return Err(Error::Protocol {
281+
details: RECV_TX_UPDATE,
282+
})
224283
}
225284
_ => continue,
226285
}
@@ -231,41 +290,47 @@ where
231290

232291
/// Print `num` [`ServerMessage::TransactionUpdate`] messages as JSON.
233292
/// If `num` is `None`, keep going indefinitely.
234-
async fn consume_transaction_updates<S>(
235-
ws: &mut S,
236-
num: Option<u32>,
237-
module_def: &RawModuleDefV9,
238-
) -> anyhow::Result<bool>
293+
async fn consume_transaction_updates<S>(ws: &mut S, num: Option<u32>, module_def: &RawModuleDefV9) -> Result<(), Error>
239294
where
240-
S: TryStream<Ok = WsMessage> + Unpin,
241-
S::Error: std::error::Error + Send + Sync + 'static,
295+
S: TryStream<Ok = WsMessage, Error = WsError> + Unpin,
242296
{
243297
let mut stdout = tokio::io::stdout();
244298
let mut num_received = 0;
245299
loop {
246300
if num.is_some_and(|n| num_received >= n) {
247-
break Ok(true);
301+
return Ok(());
248302
}
249-
let Some(msg) = ws.try_next().await? else {
303+
let Some(msg) = ws.try_next().await.map_err(|source| Error::Websocket { source })? else {
250304
eprintln!("disconnected by server");
251-
break Ok(false);
305+
return Err(Error::Websocket {
306+
source: WsError::ConnectionClosed,
307+
});
252308
};
253309

254310
let Some(msg) = parse_msg_json(&msg) else { continue };
255311
match msg {
256312
ws::ServerMessage::InitialSubscription(_) => {
257-
anyhow::bail!("protocol error: received a second initial subscription update")
313+
return Err(Error::Protocol {
314+
details: "received a second initial subscription update",
315+
})
258316
}
259317
ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { update, .. })
260318
| ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate {
261319
status: ws::UpdateStatus::Committed(update),
262320
..
263321
}) => {
264-
let output = serde_json::to_string(&reformat_update(&update, module_def)?)? + "\n";
322+
let output = format_output_json(&update, module_def)?;
265323
stdout.write_all(output.as_bytes()).await?;
266324
num_received += 1;
267325
}
268326
_ => continue,
269327
}
270328
}
271329
}
330+
331+
fn format_output_json(msg: &ws::DatabaseUpdate<JsonFormat>, schema: &RawModuleDefV9) -> Result<String, Error> {
332+
let formatted = reformat_update(msg, schema).map_err(|source| Error::Reformat { source })?;
333+
let output = serde_json::to_string(&formatted)? + "\n";
334+
335+
Ok(output)
336+
}

crates/core/src/error.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,6 @@ pub enum DatabaseError {
8282
DatabasedOpened(PathBuf, anyhow::Error),
8383
}
8484

85-
// FIXME: reduce type size
86-
#[expect(clippy::large_enum_variant)]
8785
#[derive(Error, Debug, EnumAsInner)]
8886
pub enum DBError {
8987
#[error("LibError: {0}")]

crates/core/src/host/module_host.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,6 @@ pub enum InitDatabaseError {
539539
Other(anyhow::Error),
540540
}
541541

542-
// FIXME: reduce type size
543-
#[expect(clippy::large_enum_variant)]
544542
#[derive(thiserror::Error, Debug)]
545543
pub enum ClientConnectedError {
546544
#[error(transparent)]

crates/core/src/sql/execute.rs

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,11 @@ pub(crate) mod tests {
11261126
Ok(())
11271127
}
11281128

1129+
/// Test we are protected against stack overflows when:
1130+
/// 1. The query is too large (too many characters)
1131+
/// 2. The AST is too deep
1132+
///
1133+
/// Exercise the limit [`recursion::MAX_RECURSION_EXPR`]
11291134
#[test]
11301135
fn test_large_query_no_panic() -> ResultTest<()> {
11311136
let db = TestDB::durable()?;
@@ -1138,16 +1143,43 @@ pub(crate) mod tests {
11381143
)
11391144
.unwrap();
11401145

1141-
let mut query = "select * from test where ".to_string();
1142-
for x in 0..1_000 {
1143-
for y in 0..1_000 {
1144-
let fragment = format!("((x = {x}) and y = {y}) or");
1145-
query.push_str(&fragment);
1146+
let build_query = |total| {
1147+
let mut sql = "select * from test where ".to_string();
1148+
for x in 1..total {
1149+
let fragment = format!("x = {x} or ");
1150+
sql.push_str(&fragment.repeat((total - 1) as usize));
11461151
}
1152+
sql.push_str("(y = 0)");
1153+
sql
1154+
};
1155+
let run = |db: &RelationalDB, sep: char, sql_text: &str| {
1156+
run_for_testing(db, sql_text).map_err(|e| e.to_string().split(sep).next().unwrap_or_default().to_string())
1157+
};
1158+
let sql = build_query(1_000);
1159+
assert_eq!(
1160+
run(&db, ':', &sql),
1161+
Err("SQL query exceeds maximum allowed length".to_string())
1162+
);
1163+
1164+
let sql = build_query(41); // This causes stack overflow without the limit
1165+
assert_eq!(run(&db, ',', &sql), Err("Recursion limit exceeded".to_string()));
1166+
1167+
let sql = build_query(40); // The max we can with the current limit
1168+
assert!(run(&db, ',', &sql).is_ok(), "Expected query to run without panic");
1169+
1170+
// Check no overflow with lot of joins
1171+
let mut sql = "SELECT test.* FROM test ".to_string();
1172+
// We could push up to 700 joins without overflow as long we don't have any conditions,
1173+
// but here execution become too slow.
1174+
// TODO: Move this test to the `Plan`
1175+
for i in 0..200 {
1176+
sql.push_str(&format!("JOIN test AS m{i} ON test.x = m{i}.y "));
11471177
}
1148-
query.push_str("((x = 1000) and (y = 1000))");
11491178

1150-
assert!(run_for_testing(&db, &query).is_err());
1179+
assert!(
1180+
run(&db, ',', &sql).is_ok(),
1181+
"Query with many joins and conditions should not overflow"
1182+
);
11511183
Ok(())
11521184
}
11531185

crates/expr/src/errors.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,6 @@ pub struct DuplicateName(pub String);
122122
#[error("`filter!` does not support column projections; Must return table rows")]
123123
pub struct FilterReturnType;
124124

125-
// FIXME: reduce type size
126-
#[expect(clippy::large_enum_variant)]
127125
#[derive(Error, Debug)]
128126
pub enum TypingError {
129127
#[error(transparent)]

crates/expr/src/lib.rs

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
1919
use spacetimedb_sats::algebraic_value::ser::ValueSerializer;
2020
use spacetimedb_schema::schema::ColumnSchema;
2121
use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral};
22+
use spacetimedb_sql_parser::parser::recursion;
2223

2324
pub mod check;
2425
pub mod errors;
@@ -78,8 +79,14 @@ pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> T
7879
}
7980
}
8081

81-
/// Type check and lower a [SqlExpr] into a logical [Expr].
82-
pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult<Expr> {
82+
// These types determine the size of each stack frame during type checking.
83+
// Changing their sizes will require updating the recursion limit to avoid stack overflows.
84+
const _: () = assert!(size_of::<TypingResult<Expr>>() == 64);
85+
const _: () = assert!(size_of::<SqlExpr>() == 40);
86+
87+
fn _type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>, depth: usize) -> TypingResult<Expr> {
88+
recursion::guard(depth, recursion::MAX_RECURSION_TYP_EXPR, "expr::type_expr")?;
89+
8390
match (expr, expected) {
8491
(SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(AlgebraicType::Bool)) => Ok(Expr::bool(v)),
8592
(SqlExpr::Lit(SqlLiteral::Bool(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
@@ -117,21 +124,21 @@ pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&Algebra
117124
}))
118125
}
119126
(SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => {
120-
let a = type_expr(vars, *a, Some(&AlgebraicType::Bool))?;
121-
let b = type_expr(vars, *b, Some(&AlgebraicType::Bool))?;
127+
let a = _type_expr(vars, *a, Some(&AlgebraicType::Bool), depth + 1)?;
128+
let b = _type_expr(vars, *b, Some(&AlgebraicType::Bool), depth + 1)?;
122129
Ok(Expr::LogOp(op, Box::new(a), Box::new(b)))
123130
}
124131
(SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) if matches!(&*a, SqlExpr::Lit(_)) => {
125-
let b = type_expr(vars, *b, None)?;
126-
let a = type_expr(vars, *a, Some(b.ty()))?;
132+
let b = _type_expr(vars, *b, None, depth + 1)?;
133+
let a = _type_expr(vars, *a, Some(b.ty()), depth + 1)?;
127134
if !op_supports_type(op, a.ty()) {
128135
return Err(InvalidOp::new(op, a.ty()).into());
129136
}
130137
Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
131138
}
132139
(SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => {
133-
let a = type_expr(vars, *a, None)?;
134-
let b = type_expr(vars, *b, Some(a.ty()))?;
140+
let a = _type_expr(vars, *a, None, depth + 1)?;
141+
let b = _type_expr(vars, *b, Some(a.ty()), depth + 1)?;
135142
if !op_supports_type(op, a.ty()) {
136143
return Err(InvalidOp::new(op, a.ty()).into());
137144
}
@@ -144,6 +151,11 @@ pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&Algebra
144151
}
145152
}
146153

154+
/// Type check and lower a [SqlExpr] into a logical [Expr].
155+
pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult<Expr> {
156+
_type_expr(vars, expr, expected, 0)
157+
}
158+
147159
/// Is this type compatible with this binary operator?
148160
fn op_supports_type(_op: BinOp, t: &AlgebraicType) -> bool {
149161
t.is_bool()

0 commit comments

Comments
 (0)