Commit 4a20c81

fix: connection lifecycle callbacks (#4935)
# Description of Changes

This pull request improves the handling of connection lifecycle events in the Rust client SDK for SpacetimeDB, particularly distinguishing between connection failures and disconnections. It introduces a new `ConnectionLifecycle` state machine to track connection progress and ensures that the correct callback (`on_connect_error` or `on_disconnect`) is invoked based on the connection state.

**Changes**

* Added a `ConnectionLifecycle` enum to track the connection state (`Connecting`, `Connected`, `Ended`).
* Refactored error handling so that if a connection fails before establishment, the `on_connect_error` callback is invoked; if the connection fails after establishment, the `on_disconnect` callback is invoked. See `end_connection`.
* Updated where disconnections are handled (`advance_one_message_blocking`, `advance_one_message_async`, and message processing) to use `end_connection`.
* Improved handling of user-initiated disconnects during the connection process to avoid reporting them as connection errors and to ensure proper cleanup.

# API and ABI breaking changes

If callers relied on a failed connection attempt firing `on_disconnect` rather than `on_connect_error`, this changes that behavior.

# Expected complexity level and risk

Maybe a 2? Seems pretty low risk, but I'm still new to the codebase, so please double check. This doesn't fix the websocket issues; that'll be for another day. I noticed `websocket.rs` has some places where it just drops and the error isn't handled properly. We could technically surface that information and run our callbacks with more specific error messages.

# Testing

I had an agent build and run loads of tests for this but didn't commit those, since it would have made the PR massive. I was planning on testing locally, though, to see if I could trigger a connection failure at some point, maybe via an invalid access token.
1 parent 8cd2936 commit 4a20c81

2 files changed

Lines changed: 95 additions & 53 deletions

docs/docs/00200-core-concepts/00600-clients/00500-rust-reference.md

Lines changed: 2 additions & 4 deletions
@@ -147,9 +147,7 @@ impl DbConnectionBuilder {
 }
 ```
 
-Chain a call to `.on_connect_error(callback)` to your builder to register a callback to run when your connection fails.
-
-A known bug in the SpacetimeDB Rust client SDK currently causes this callback never to be invoked. [`on_disconnect`](#callback-on_disconnect) callbacks are invoked instead.
+Chain a call to `.on_connect_error(callback)` to your builder to register a callback to run when a connection attempt fails asynchronously. Errors which prevent `build` from creating the connection are returned by `build` instead.
 
 #### Callback `on_disconnect`

@@ -162,7 +160,7 @@ impl DbConnectionBuilder {
 }
 ```
 
-Chain a call to `.on_disconnect(callback)` to your builder to register a callback to run when your `DbConnection` disconnects from the remote database, either as a result of a call to [`disconnect`](#method-disconnect) or due to an error.
+Chain a call to `.on_disconnect(callback)` to your builder to register a callback to run when your established `DbConnection` disconnects from the remote database, either as a result of a call to [`disconnect`](#method-disconnect) or due to an error.
 
 #### Method `with_token`

sdks/rust/src/db_connection.rs

Lines changed: 93 additions & 49 deletions
@@ -137,18 +137,25 @@ impl<M: SpacetimeModule> DbContextImpl<M> {
     fn process_message(&self, msg: ParsedMessage<M>) -> crate::Result<()> {
         self.debug_log(|out| writeln!(out, "`process_message`: {msg:?}"));
         match msg {
-            // Error: treat this as an erroneous disconnect.
-            ParsedMessage::Error(e) => {
-                let disconnect_ctx = self.make_event_ctx(Some(e.clone()));
-                self.invoke_disconnected(&disconnect_ctx);
-                Err(e)
-            }
+            // Error: route as a connection error if we never finished connecting,
+            // otherwise treat it as an erroneous disconnect.
+            ParsedMessage::Error(e) => Err(self.end_connection(Some(e))),

             // Initial `IdentityToken` message:
             // confirm that the received identity and connection ID are what we expect,
-            // store them,
-            // then invoke the on_connect callback.
+            // store them, then invoke the on_connect callback.
             ParsedMessage::IdentityToken(identity, token, conn_id) => {
+                let on_connect = {
+                    let mut inner = self.inner.lock().unwrap();
+                    match inner.connection_lifecycle {
+                        ConnectionLifecycle::Connecting => {
+                            inner.connection_lifecycle = ConnectionLifecycle::Connected;
+                            inner.on_connect.take()
+                        }
+                        ConnectionLifecycle::Connected => None,
+                        ConnectionLifecycle::Ended => return Ok(()),
+                    }
+                };
                 {
                     // Don't hold the `self.identity` lock while running callbacks.
                     // Callbacks can (will) call [`DbContext::identity`], which acquires that lock,
@@ -170,8 +177,7 @@ impl<M: SpacetimeModule> DbContextImpl<M> {
                     }
                     *conn_id_store = Some(conn_id);
                 }
-                let mut inner = self.inner.lock().unwrap();
-                if let Some(on_connect) = inner.on_connect.take() {
+                if let Some(on_connect) = on_connect {
                     let ctx = <M::DbConnection as DbConnection>::new(self.clone());
                     on_connect(&ctx, identity, &token);
                 }
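
The two hunks above move `on_connect.take()` into a scoped block, so the `inner` guard is released before the callback runs. Below is a minimal, self-contained sketch of that pattern. `Conn`, `Inner`, `Lifecycle`, and `handle_initial_message` are hypothetical stand-ins invented for this illustration, not SDK types, and the real method also validates and stores the identity and connection ID.

```rust
use std::sync::Mutex;

/// Stand-in for the commit's `ConnectionLifecycle` enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Lifecycle {
    Connecting,
    Connected,
    Ended,
}

/// Hypothetical, reduced version of the lock-protected state.
struct Inner {
    lifecycle: Lifecycle,
    on_connect: Option<Box<dyn FnOnce(&Conn) + Send>>,
}

/// Hypothetical connection handle; not an SDK type.
struct Conn {
    inner: Mutex<Inner>,
}

impl Conn {
    fn new(on_connect: impl FnOnce(&Conn) + Send + 'static) -> Self {
        Conn {
            inner: Mutex::new(Inner {
                lifecycle: Lifecycle::Connecting,
                on_connect: Some(Box::new(on_connect)),
            }),
        }
    }

    fn lifecycle(&self) -> Lifecycle {
        self.inner.lock().unwrap().lifecycle
    }

    /// Handle the initial connection message.
    /// Returns `true` if the `on_connect` callback ran.
    fn handle_initial_message(&self) -> bool {
        // Decide what to do while holding the lock, but only *take* the callback here.
        let on_connect = {
            let mut inner = self.inner.lock().unwrap();
            match inner.lifecycle {
                Lifecycle::Connecting => {
                    inner.lifecycle = Lifecycle::Connected;
                    inner.on_connect.take()
                }
                // A repeated initial message must not fire the callback again.
                Lifecycle::Connected => None,
                // The connection already ended (e.g. the user cancelled): ignore.
                Lifecycle::Ended => return false,
            }
        }; // The guard is dropped here, before any user code runs.

        match on_connect {
            Some(callback) => {
                // The callback may lock `inner` again without deadlocking.
                callback(self);
                true
            }
            None => false,
        }
    }
}
```

In this sketch, a callback that calls `conn.lifecycle()` observes `Connected`, because the state transition happens before the callback is invoked and the lock is free by then.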
@@ -306,23 +312,47 @@ impl<M: SpacetimeModule> DbContextImpl<M> {
         applied_diff.invoke_row_callbacks(&row_event_ctx, &mut inner.db_callbacks);
     }

-    /// Invoke the on-disconnect callback, and mark [`Self::is_active`] false.
-    fn invoke_disconnected(&self, ctx: &M::ErrorContext) {
+    /// Mark the connection lifecycle as ended, route the terminal event to the
+    /// appropriate connection callback, and mark [`Self::is_active`] false.
+    ///
+    /// Returns the terminal error that should be returned from `advance_*` methods.
+    fn end_connection(&self, callback_error: Option<crate::Error>) -> crate::Error {
         let mut inner = self.inner.lock().unwrap();
-        // When we disconnect, we first call the on_disconnect method,
-        // then we call the `on_error` method for all subscriptions.
-        // We don't change the client cache at all.
+        let return_error = callback_error.clone().unwrap_or(crate::Error::Disconnected);
+
+        let lifecycle = inner.connection_lifecycle;
+        if lifecycle == ConnectionLifecycle::Ended {
+            return return_error;
+        }
+        inner.connection_lifecycle = ConnectionLifecycle::Ended;

         // Set `send_chan` to `None`, since `Self::is_active` checks that.
         *self.send_chan.lock().unwrap() = None;

-        // Grap the `on_disconnect` callback and invoke it.
-        if let Some(disconnect_callback) = inner.on_disconnect.take() {
-            disconnect_callback(ctx, ctx.event().clone());
-        }
+        match lifecycle {
+            ConnectionLifecycle::Connecting => {
+                let callback_error = callback_error.unwrap_or_else(|| crate::Error::FailedToConnect {
+                    source: InternalError::new("Connection closed before receiving the initial connection message"),
+                });
+                let ctx: M::ErrorContext = self.make_event_ctx(Some(callback_error.clone()));
+                if let Some(connect_error_callback) = inner.on_connect_error.take() {
+                    connect_error_callback(&ctx, callback_error.clone());
+                }
+                callback_error
+            }
+            ConnectionLifecycle::Connected => {
+                let ctx: M::ErrorContext = self.make_event_ctx(callback_error.clone());
+                if let Some(disconnect_callback) = inner.on_disconnect.take() {
+                    disconnect_callback(&ctx, callback_error.clone());
+                }
+
+                // Call the `on_disconnect` method for all subscriptions.
+                inner.subscriptions.on_disconnect(&ctx);

-        // Call the `on_disconnect` method for all subscriptions.
-        inner.subscriptions.on_disconnect(ctx);
+                return_error
+            }
+            ConnectionLifecycle::Ended => return_error,
+        }
     }

     fn make_event_ctx<E, Ctx: AbstractEventContext<Module = M, Event = E>>(&self, event: E) -> Ctx {
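
The routing in `end_connection` above can be modelled in isolation. The sketch below is a simplified approximation, not the SDK code: `Error`, `Inner`, `Conn`, and `recording_conn` are hypothetical names for this example, the error variants are reduced to strings, and the subscription and `send_chan` handling is omitted. It shows two properties of the hunk: the callback is chosen by the lifecycle state at the time the connection ends, and a second call is a no-op that only returns an error.

```rust
use std::sync::{Arc, Mutex};

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Lifecycle {
    Connecting,
    Connected,
    Ended,
}

/// Simplified stand-in for `crate::Error`; these variants are assumptions.
#[derive(Clone, Debug, PartialEq, Eq)]
enum Error {
    Disconnected,
    FailedToConnect(String),
    Server(String),
}

struct Inner {
    lifecycle: Lifecycle,
    on_connect_error: Option<Box<dyn FnOnce(Error) + Send>>,
    on_disconnect: Option<Box<dyn FnOnce(Option<Error>) + Send>>,
}

struct Conn {
    inner: Mutex<Inner>,
}

impl Conn {
    /// End the connection once, fire the matching callback,
    /// and return the error the caller should propagate.
    fn end_connection(&self, callback_error: Option<Error>) -> Error {
        let mut inner = self.inner.lock().unwrap();
        let return_error = callback_error.clone().unwrap_or(Error::Disconnected);

        let lifecycle = inner.lifecycle;
        if lifecycle == Lifecycle::Ended {
            // Already terminal: never fire a callback twice.
            return return_error;
        }
        inner.lifecycle = Lifecycle::Ended;

        match lifecycle {
            Lifecycle::Connecting => {
                // A close without an explicit error still counts as a failed attempt.
                let err = callback_error.unwrap_or_else(|| {
                    Error::FailedToConnect("closed before the initial connection message".to_string())
                });
                if let Some(callback) = inner.on_connect_error.take() {
                    callback(err.clone());
                }
                err
            }
            Lifecycle::Connected => {
                if let Some(callback) = inner.on_disconnect.take() {
                    callback(callback_error);
                }
                return_error
            }
            Lifecycle::Ended => return_error,
        }
    }
}

/// Build a `Conn` whose callbacks append a line to a shared log.
fn recording_conn(lifecycle: Lifecycle) -> (Conn, Arc<Mutex<Vec<String>>>) {
    let log = Arc::new(Mutex::new(Vec::new()));
    let (log_a, log_b) = (log.clone(), log.clone());
    let inner = Inner {
        lifecycle,
        on_connect_error: Some(Box::new(move |e: Error| {
            log_a.lock().unwrap().push(format!("connect_error: {e:?}"))
        })),
        on_disconnect: Some(Box::new(move |e: Option<Error>| {
            log_b.lock().unwrap().push(format!("disconnect: {e:?}"))
        })),
    };
    (Conn { inner: Mutex::new(inner) }, log)
}
```

Storing each callback as `Option<Box<dyn FnOnce…>>` and calling `take()` means the type system already prevents a second invocation; the explicit `Ended` check additionally keeps the *other* callback from firing later.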
@@ -447,10 +477,19 @@ impl<M: SpacetimeModule> DbContextImpl<M> {

             // Disconnect: close the connection.
             PendingMutation::Disconnect => {
+                {
+                    let mut inner = self.inner.lock().unwrap();
+                    if inner.connection_lifecycle == ConnectionLifecycle::Connecting {
+                        // If the user cancels before the initial connection finishes,
+                        // don't report that as a connection error.
+                        inner.connection_lifecycle = ConnectionLifecycle::Ended;
+                    }
+                }
                 // Set `send_chan` to `None`, since `Self::is_active` checks that.
                 // This will close the WebSocket loop in websocket.rs,
                 // sending a close frame to the server,
-                // eventually resulting in disconnect callbacks being called.
+                // eventually resulting in disconnect callbacks being called
+                // if the initial connection had completed.
                 *self.send_chan.lock().unwrap() = None;
             }

@@ -540,11 +579,7 @@ impl<M: SpacetimeModule> DbContextImpl<M> {
             // `Stream::poll_next`. No comment on whether this is a good mental
             // model or not.
             let res = match get_lock_sync(&self.recv).try_next() {
-                Ok(None) => {
-                    let disconnect_ctx = self.make_event_ctx(None);
-                    self.invoke_disconnected(&disconnect_ctx);
-                    Err(crate::Error::Disconnected)
-                }
+                Ok(None) => Err(self.end_connection(None)),
                 Err(_) => Ok(false),
                 Ok(Some(msg)) => self.process_message(msg).map(|_| true),
             };
@@ -599,11 +634,7 @@ impl<M: SpacetimeModule> DbContextImpl<M> {
     pub fn advance_one_message_blocking(&self) -> crate::Result<()> {
         match self.runtime.block_on(self.get_message()) {
             Message::Local(pending) => self.apply_mutation(pending),
-            Message::Ws(None) => {
-                let disconnect_ctx = self.make_event_ctx(None);
-                self.invoke_disconnected(&disconnect_ctx);
-                Err(crate::Error::Disconnected)
-            }
+            Message::Ws(None) => Err(self.end_connection(None)),
             Message::Ws(Some(msg)) => self.process_message(msg),
         }
     }
@@ -614,11 +645,7 @@ impl<M: SpacetimeModule> DbContextImpl<M> {
     pub async fn advance_one_message_async(&self) -> crate::Result<()> {
         match self.get_message().await {
             Message::Local(pending) => self.apply_mutation(pending),
-            Message::Ws(None) => {
-                let disconnect_ctx = self.make_event_ctx(None);
-                self.invoke_disconnected(&disconnect_ctx);
-                Err(crate::Error::Disconnected)
-            }
+            Message::Ws(None) => Err(self.end_connection(None)),
             Message::Ws(Some(msg)) => self.process_message(msg),
         }
     }
@@ -784,6 +811,16 @@ type OnConnectErrorCallback<M> = Box<dyn FnOnce(&<M as SpacetimeModule>::ErrorCo
 type OnDisconnectCallback<M> =
     Box<dyn FnOnce(&<M as SpacetimeModule>::ErrorContext, Option<crate::Error>) + Send + 'static>;

+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+enum ConnectionLifecycle {
+    /// Waiting for the server's initial connection message.
+    Connecting,
+    /// The server has sent the initial connection message.
+    Connected,
+    /// The connection has already reached a terminal lifecycle state.
+    Ended,
+}
+
 /// All the stuff in a [`DbContextImpl`] which can safely be locked while invoking callbacks.
 pub(crate) struct DbContextImplInner<M: SpacetimeModule> {
     /// `Some` if not within the context of an outer runtime. The `Runtime` must
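
Read together with the other hunks in this commit, the enum above behaves like a small state machine. The following sketch restates those transitions as a pure function, as one reading of the diff rather than code from the SDK. The `Event` and `Callback` enums and the `step` function are hypothetical names introduced only for this illustration.

```rust
/// Lifecycle states, mirroring the enum added in this commit.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ConnectionLifecycle {
    Connecting,
    Connected,
    Ended,
}

/// Hypothetical event type for this sketch; the SDK has no such enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Event {
    /// The server's initial connection message arrived.
    InitialMessage,
    /// The connection closed or the server reported an error.
    Terminated,
    /// The user requested a disconnect.
    UserDisconnect,
}

/// Hypothetical label for which user callback should fire.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Callback {
    OnConnect,
    OnConnectError,
    OnDisconnect,
}

/// Return the next state and the callback (if any) to invoke.
fn step(state: ConnectionLifecycle, event: Event) -> (ConnectionLifecycle, Option<Callback>) {
    use ConnectionLifecycle::*;
    match (state, event) {
        (Connecting, Event::InitialMessage) => (Connected, Some(Callback::OnConnect)),
        // Failing before establishment is a connection error, not a disconnect.
        (Connecting, Event::Terminated) => (Ended, Some(Callback::OnConnectError)),
        // A user cancel during connection ends the lifecycle silently.
        (Connecting, Event::UserDisconnect) => (Ended, None),
        // A repeated initial message does not fire `on_connect` again.
        (Connected, Event::InitialMessage) => (Connected, None),
        (Connected, Event::Terminated) => (Ended, Some(Callback::OnDisconnect)),
        // The request only closes the send channel; `on_disconnect` fires
        // later, when the closed connection is observed as `Terminated`.
        (Connected, Event::UserDisconnect) => (Connected, None),
        // `Ended` is terminal: nothing fires again.
        (Ended, _) => (Ended, None),
    }
}
```

Under this reading, at most one of `OnConnectError` and `OnDisconnect` can fire over the life of a connection, and neither fires when the user cancels while still connecting.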
@@ -796,9 +833,8 @@ pub(crate) struct DbContextImplInner<M: SpacetimeModule> {
     reducer_callbacks: ReducerCallbacks<M>,
     pub(crate) subscriptions: SubscriptionManager<M>,

+    connection_lifecycle: ConnectionLifecycle,
     on_connect: Option<OnConnectCallback<M>>,
-    #[allow(unused)]
-    // TODO: Make use of this to handle `ParsedMessage::Error` before receiving `IdentityToken`.
     on_connect_error: Option<OnConnectErrorCallback<M>>,
     on_disconnect: Option<OnDisconnectCallback<M>>,

@@ -1040,9 +1076,10 @@ but you must call one of them, or else the connection will never progress.
     /// If this method is not invoked, or `None` is supplied,
     /// the SpacetimeDB host will generate a new anonymous `Identity`.
     ///
-    /// If the passed token is invalid or rejected by the host,
-    /// the connection will fail asynchrnonously.
-    // FIXME: currently this causes `disconnect` to be called rather than `on_connect_error`.
+    /// If the token is rejected before a connection context is created, [`Self::build`]
+    /// returns an error. If the host reports the rejection after the WebSocket is
+    /// established but before the initial connection message, [`Self::on_connect_error`]
+    /// is invoked.
     pub fn with_token(mut self, token: Option<impl Into<String>>) -> Self {
         self.token = token.map(|token| token.into());
         self
@@ -1095,9 +1132,10 @@ but you must call one of them, or else the connection will never progress.
         self
     }

-    /// Register a callback to run when the connection is successfully initiated.
+    /// Register a callback to run when the connection is successfully established.
     ///
-    /// The callback will receive three arguments:
+    /// The connection is established after the initial connection message is
+    /// received from the host. The callback will receive three arguments:
     /// - The `DbConnection` which has successfully connected.
     /// - The `Identity` of the successful connection.
     /// - The private access token which can be used to later re-authenticate as the same `Identity`.
@@ -1116,9 +1154,11 @@ Instead of registering multiple `on_connect` callbacks, register a single callba
         self
     }

-    /// Register a callback to run when the connection fails asynchronously,
-    /// e.g. due to invalid credentials.
-    // FIXME: currently never called; `on_disconnect` is called instead.
+    /// Register a callback to run when a connection attempt fails asynchronously.
+    ///
+    /// This callback is invoked only before the initial connection message is
+    /// received from the host. Errors which prevent [`Self::build`] from creating
+    /// a connection are returned by [`Self::build`] instead.
     pub fn on_connect_error(mut self, callback: impl FnOnce(&M::ErrorContext, crate::Error) + Send + 'static) -> Self {
         if self.on_connect_error.is_some() {
             panic!(
@@ -1132,8 +1172,11 @@ Instead of registering multiple `on_connect_error` callbacks, register a single
         self
     }

-    /// Register a callback to run when the connection is closed.
-    // FIXME: currently also called when the connection fails asynchronously, instead of `on_connect_error`.
+    /// Register a callback to run when an established connection is closed.
+    ///
+    /// The connection is established after the initial connection message is
+    /// received from the host. Connection failures before that point invoke
+    /// [`Self::on_connect_error`] instead.
     pub fn on_disconnect(
         mut self,
         callback: impl FnOnce(&M::ErrorContext, Option<crate::Error>) + Send + 'static,
@@ -1166,6 +1209,7 @@ fn build_db_ctx_inner<M: SpacetimeModule>(
         reducer_callbacks: ReducerCallbacks::default(),
         subscriptions: SubscriptionManager::default(),

+        connection_lifecycle: ConnectionLifecycle::Connecting,
         on_connect: on_connect_cb,
         on_connect_error: on_connect_error_cb,
         on_disconnect: on_disconnect_cb,
