Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions rivetkit-rust/packages/rivetkit-core/src/actor/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -683,14 +683,17 @@ impl ActorContext {
);
conn.configure_hibernation(hibernation);
self.prepare_managed_conn(&conn);
self.insert_existing(conn.clone());

if let Err(error) = prepare_connection(&conn) {
self.remove_existing(conn.id());
return Err(error);
}

if let Err(error) = self.emit_connection_open(&conn, params, request).await {
self
.emit_connection_preflight(&conn, params.clone(), request.clone())
.await?;
self.insert_existing(conn.clone());

if let Err(error) = self.emit_connection_open(&conn, request).await {
self.remove_existing(conn.id());
return Err(error);
}
Expand Down Expand Up @@ -868,15 +871,13 @@ impl ActorContext {
async fn emit_connection_open(
&self,
conn: &ConnHandle,
params: Vec<u8>,
request: Option<Request>,
) -> Result<()> {
let config = self.connection_config();
let (reply_tx, reply_rx) = oneshot::channel();
self.try_send_actor_event(
ActorEvent::ConnectionOpen {
conn: conn.clone(),
params,
request,
reply: Reply::from(reply_tx),
},
Expand All @@ -889,6 +890,33 @@ impl ActorContext {
Ok(())
}

async fn emit_connection_preflight(
&self,
conn: &ConnHandle,
params: Vec<u8>,
request: Option<Request>,
) -> Result<()> {
let config = self.connection_config();
let timeout_duration = config
.on_before_connect_timeout
.saturating_add(config.create_conn_state_timeout);
let (reply_tx, reply_rx) = oneshot::channel();
self.try_send_actor_event(
ActorEvent::ConnectionPreflight {
conn: conn.clone(),
params,
request,
reply: Reply::from(reply_tx),
},
"connection_preflight",
)?;
timeout(timeout_duration, reply_rx)
.await
.with_context(|| timeout_message("connection_preflight", timeout_duration))?
.context("receive connection_preflight reply")??;
Ok(())
}

pub(crate) fn connection(&self, conn_id: &str) -> Option<ConnHandle> {
self.0.connections.read().get(conn_id).cloned()
}
Expand Down
8 changes: 7 additions & 1 deletion rivetkit-rust/packages/rivetkit-core/src/actor/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,17 @@ pub enum ActorEvent {
request: Option<Request>,
reply: Reply<()>,
},
ConnectionOpen {
ConnectionPreflight {
conn: ConnHandle,
params: Vec<u8>,
request: Option<Request>,
reply: Reply<()>,
},
ConnectionOpen {
conn: ConnHandle,
request: Option<Request>,
reply: Reply<()>,
},
ConnectionClosed {
conn: ConnHandle,
},
Expand Down Expand Up @@ -342,6 +347,7 @@ impl ActorEvent {
Self::HttpRequest { .. } => "http_request",
Self::QueueSend { .. } => "queue_send",
Self::WebSocketOpen { .. } => "websocket_open",
Self::ConnectionPreflight { .. } => "connection_preflight",
Self::ConnectionOpen { .. } => "connection_open",
Self::ConnectionClosed { .. } => "connection_closed",
Self::SubscribeRequest { .. } => "subscribe_request",
Expand Down
100 changes: 99 additions & 1 deletion rivetkit-rust/packages/rivetkit-core/tests/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod moved_tests {
use std::collections::BTreeSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

use parking_lot::Mutex;
use tokio::sync::{Barrier, mpsc};
Expand Down Expand Up @@ -70,6 +71,101 @@ mod moved_tests {
assert!(ctx.connection("conn-preloaded").is_some());
}

#[tokio::test]
async fn pending_connection_is_invisible_until_preflight_succeeds() {
let ctx = ActorContext::new_with_kv(
"actor-preflight-visibility",
"actor",
Vec::new(),
"local",
Kv::new_in_memory(),
);
ctx.configure_connection_runtime(crate::actor::config::ActorConfig::default());
let (events_tx, mut events_rx) = mpsc::unbounded_channel();
ctx.configure_actor_events(Some(events_tx));

let events_ctx = ctx.clone();
let event_task = tokio::spawn(async move {
let preflight_conn_id = match events_rx.recv().await.expect("preflight event") {
ActorEvent::ConnectionPreflight { conn, reply, .. } => {
assert!(events_ctx.connection(conn.id()).is_none());
conn.set_state_initial(vec![7]);
let conn_id = conn.id().to_owned();
reply.send(Ok(()));
conn_id
}
other => panic!("unexpected event: {other:?}"),
};

match events_rx.recv().await.expect("open event") {
ActorEvent::ConnectionOpen { conn, reply, .. } => {
assert_eq!(conn.id(), preflight_conn_id);
let visible = events_ctx
.connection(conn.id())
.expect("connection should be visible for onConnect");
assert_eq!(visible.state(), vec![7]);
reply.send(Ok(()));
}
other => panic!("unexpected event: {other:?}"),
}
});

let conn = ctx
.connect_with_state(vec![1], false, None, None, async { Ok(vec![2]) })
.await
.expect("connection should succeed");

assert_eq!(conn.state(), vec![7]);
assert!(ctx.connection(conn.id()).is_some());
event_task.await.expect("event task should complete");
}

#[tokio::test]
async fn failed_preflight_never_exposes_connection() {
let ctx = ActorContext::new_with_kv(
"actor-preflight-failure",
"actor",
Vec::new(),
"local",
Kv::new_in_memory(),
);
ctx.configure_connection_runtime(crate::actor::config::ActorConfig::default());
let (events_tx, mut events_rx) = mpsc::unbounded_channel();
ctx.configure_actor_events(Some(events_tx));
let failed_conn_id = Arc::new(Mutex::new(None::<String>));

let events_ctx = ctx.clone();
let event_failed_conn_id = failed_conn_id.clone();
let event_task = tokio::spawn(async move {
match events_rx.recv().await.expect("preflight event") {
ActorEvent::ConnectionPreflight { conn, reply, .. } => {
assert!(events_ctx.connection(conn.id()).is_none());
*event_failed_conn_id.lock() = Some(conn.id().to_owned());
reply.send(Err(anyhow::anyhow!("reject preflight")));
}
other => panic!("unexpected event: {other:?}"),
}
assert!(
tokio::time::timeout(Duration::from_millis(20), events_rx.recv())
.await
.is_err()
);
});

let error = ctx
.connect_with_state(vec![1], false, None, None, async { Ok(vec![2]) })
.await
.expect_err("connection should fail");

assert!(format!("{error:#}").contains("reject preflight"));
let conn_id = failed_conn_id
.lock()
.clone()
.expect("failed connection id should be recorded");
assert!(ctx.connection(&conn_id).is_none());
event_task.await.expect("event task should complete");
}

#[test]
fn persisted_connection_uses_ts_v4_fixed_id_wire_format() {
let persisted = PersistedConnection {
Expand Down Expand Up @@ -132,6 +228,7 @@ mod moved_tests {
async move {
while let Some(event) = events_rx.recv().await {
match event {
ActorEvent::ConnectionPreflight { reply, .. } => reply.send(Ok(())),
ActorEvent::ConnectionOpen { reply, .. } => reply.send(Ok(())),
ActorEvent::ConnectionClosed { conn } => {
*observed_conn_id.lock() = Some(conn.id().to_owned());
Expand Down Expand Up @@ -218,12 +315,13 @@ mod moved_tests {
ctx.configure_lifecycle_events(Some(lifecycle_events_tx));

let open_replies = tokio::spawn(async move {
for _ in 0..2 {
for _ in 0..4 {
match actor_events_rx
.recv()
.await
.expect("open event should arrive")
{
ActorEvent::ConnectionPreflight { reply, .. } => reply.send(Ok(())),
ActorEvent::ConnectionOpen { reply, .. } => reply.send(Ok(())),
other => panic!("unexpected actor event: {other:?}"),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,17 @@ fn counter_factory() -> ActorFactory {
} => {
reply.send(Err(anyhow::anyhow!("websockets are not handled")));
}
ActorEvent::ConnectionOpen {
ActorEvent::ConnectionPreflight {
conn: _,
params: _,
request: _,
reply,
} => {
reply.send(Ok(()));
}
ActorEvent::ConnectionOpen { reply, .. } => {
reply.send(Ok(()));
}
ActorEvent::ConnectionClosed { conn: _ } => {}
ActorEvent::SubscribeRequest {
conn: _,
Expand Down
9 changes: 6 additions & 3 deletions rivetkit-rust/packages/rivetkit-core/tests/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1745,7 +1745,8 @@ mod moved_tests {
reply.send(Ok(()));
break;
}
ActorEvent::ConnectionOpen { .. } => {
ActorEvent::ConnectionPreflight { .. }
| ActorEvent::ConnectionOpen { .. } => {
panic!("hibernated connection should not refire ConnectionOpen");
}
_ => {}
Expand Down Expand Up @@ -3982,7 +3983,8 @@ mod moved_tests {
reply.send(Ok(()));
break;
}
ActorEvent::ConnectionOpen { .. } => {
ActorEvent::ConnectionPreflight { .. }
| ActorEvent::ConnectionOpen { .. } => {
panic!("hibernated connection should not refire ConnectionOpen");
}
_ => {}
Expand Down Expand Up @@ -4107,7 +4109,8 @@ mod moved_tests {
reply.send(Ok(()));
break;
}
ActorEvent::ConnectionOpen { .. } => {
ActorEvent::ConnectionPreflight { .. }
| ActorEvent::ConnectionOpen { .. } => {
panic!("dead hibernated connection should not refire ConnectionOpen");
}
_ => {}
Expand Down
5 changes: 4 additions & 1 deletion rivetkit-rust/packages/rivetkit/src/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl<A: Actor> Event<A> {
reply: Some(reply),
_p: PhantomData,
}),
ActorEvent::ConnectionOpen {
ActorEvent::ConnectionPreflight {
conn,
params,
request,
Expand All @@ -92,6 +92,9 @@ impl<A: Actor> Event<A> {
request,
reply: Some(reply),
}),
ActorEvent::ConnectionOpen { .. } => {
unreachable!("ConnectionOpen is handled by Events")
}
ActorEvent::ConnectionClosed { conn } => Self::ConnClosed(ConnClosed {
conn: ConnCtx::from(conn),
}),
Expand Down
8 changes: 8 additions & 0 deletions rivetkit-rust/packages/rivetkit/src/start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ impl<A: Actor> Events<A> {

async fn handle_runtime_event(&self, event: ActorEvent) -> Option<ActorEvent> {
match event {
ActorEvent::ConnectionOpen { reply, .. } => {
reply.send(Ok(()));
None
}
ActorEvent::DisconnectConn { conn_id, reply } => {
reply.send(self.ctx.disconnect_conn(&conn_id).await);
None
Expand All @@ -137,6 +141,10 @@ impl<A: Actor> Events<A> {

fn handle_runtime_event_sync(&self, event: ActorEvent) -> Option<ActorEvent> {
match event {
ActorEvent::ConnectionOpen { reply, .. } => {
reply.send(Ok(()));
None
}
ActorEvent::DisconnectConn { conn_id, reply } => {
let ctx = self.ctx.clone();
tokio::spawn(async move {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,17 +502,15 @@ pub(crate) async fn dispatch_event(
call_on_websocket(&callback, &ctx, conn, ws, request).await
});
}
ActorEvent::ConnectionOpen {
ActorEvent::ConnectionPreflight {
conn,
params,
request,
reply,
} => {
let on_before_connect = bindings.on_before_connect.clone();
let create_conn_state = bindings.create_conn_state.clone();
let on_connect = bindings.on_connect.clone();
let timeout = config.on_before_connect_timeout;
let connect_timeout = config.on_connect_timeout;
let create_conn_state_timeout = config.create_conn_state_timeout;
let ctx = ctx.clone();

Expand Down Expand Up @@ -542,6 +540,19 @@ pub(crate) async fn dispatch_event(
ctx.set_conn_state_initial(&conn, state)?;
}

Ok(())
});
}
ActorEvent::ConnectionOpen {
conn,
request,
reply,
} => {
let on_connect = bindings.on_connect.clone();
let connect_timeout = config.on_connect_timeout;
let ctx = ctx.clone();

spawn_reply(tasks, abort.clone(), reply, async move {
if let Some(callback) = on_connect {
with_timeout(
"onConnect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ mod moved_tests {
let conn = rivetkit_core::ConnHandle::new("conn-open", vec![1, 2, 3], Vec::new(), false);

dispatch_event(
ActorEvent::ConnectionOpen {
ActorEvent::ConnectionPreflight {
conn: conn.clone(),
params: vec![4, 5, 6],
request: None,
Expand Down
Loading
Loading