diff --git a/rivetkit-rust/packages/rivetkit-core/tests/connection.rs b/rivetkit-rust/packages/rivetkit-core/tests/connection.rs index 9524c4ee2a..f049e4dfba 100644 --- a/rivetkit-rust/packages/rivetkit-core/tests/connection.rs +++ b/rivetkit-rust/packages/rivetkit-core/tests/connection.rs @@ -166,6 +166,49 @@ mod moved_tests { event_task.await.expect("event task should complete"); } + #[tokio::test] + async fn transport_close_during_preflight_never_emits_connection_closed() { + let ctx = ActorContext::new_with_kv( + "actor-preflight-transport-close", + "actor", + Vec::new(), + "local", + Kv::new_in_memory(), + ); + ctx.configure_connection_runtime(crate::actor::config::ActorConfig::default()); + let (events_tx, mut events_rx) = mpsc::unbounded_channel(); + ctx.configure_actor_events(Some(events_tx)); + let closed_conn_id = Arc::new(Mutex::new(None::)); + + let event_closed_conn_id = closed_conn_id.clone(); + let event_task = tokio::spawn(async move { + match events_rx.recv().await.expect("preflight event") { + ActorEvent::ConnectionPreflight { conn, reply, .. } => { + conn.disconnect(Some("transport closed")) + .await + .expect("pending connection transport close should succeed"); + reply.send(Err(anyhow::anyhow!("reject after transport close"))); + } + other => panic!("unexpected event: {other:?}"), + } + + if let Ok(Some(ActorEvent::ConnectionClosed { conn })) = + tokio::time::timeout(Duration::from_millis(20), events_rx.recv()).await + { + *event_closed_conn_id.lock() = Some(conn.id().to_owned()); + } + }); + + let error = ctx + .connect_with_state(vec![1], false, None, None, async { Ok(vec![2]) }) + .await + .expect_err("connection should fail"); + + assert!(format!("{error:#}").contains("reject after transport close")); + assert_eq!(*closed_conn_id.lock(), None); + event_task.await.expect("event task should complete"); + } + #[test] fn persisted_connection_uses_ts_v4_fixed_id_wire_format() { let persisted = PersistedConnection { diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/conn-preflight-visibility.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/conn-preflight-visibility.ts index 2f55906f2e..3ce8d361ac 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/conn-preflight-visibility.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/conn-preflight-visibility.ts @@ -8,6 +8,8 @@ type ConnParams = { label?: string; beforeDelayMs?: number; createDelayMs?: number; + rejectBefore?: boolean; + rejectCreate?: boolean; }; const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)); @@ -38,6 +40,9 @@ export const connPreflightVisibilityActor = actor({ if (params?.beforeDelayMs) { await sleep(params.beforeDelayMs); } + if (params?.rejectBefore) { + throw new Error("rejected before connect"); + } }, createConnState: async (c, params: ConnParams): Promise => { c.state.createStarted += 1; @@ -45,6 +50,9 @@ export const connPreflightVisibilityActor = actor({ if (params?.createDelayMs) { await sleep(params.createDelayMs); } + if (params?.rejectCreate) { + throw new Error("rejected create conn state"); + } return { label: params?.label ?? "anonymous" }; }, onConnect: (c, conn) => { diff --git a/rivetkit-typescript/packages/rivetkit/tests/driver/conn-preflight-disconnect.test.ts b/rivetkit-typescript/packages/rivetkit/tests/driver/conn-preflight-disconnect.test.ts new file mode 100644 index 0000000000..993a41a137 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/tests/driver/conn-preflight-disconnect.test.ts @@ -0,0 +1,43 @@ +// @ts-nocheck + +import { expect, test } from "vitest"; +import { describeDriverMatrix } from "./shared-matrix"; +import { setupDriverTest } from "./shared-utils"; + +describeDriverMatrix( + "Connection Preflight Disconnect", + (driverTestConfig) => { + test("should not call onDisconnect when preflight fails", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + const handle = client.connPreflightVisibilityActor.getOrCreate([ + "failed-preflight-disconnect", + crypto.randomUUID(), + ]); + const primary = handle.connect({ label: "primary" }); + await primary.snapshot(); + + const rejectedBefore = handle.connect({ + label: "rejected-before", + rejectBefore: true, + }); + await expect(rejectedBefore.snapshot()).rejects.toThrow(); + + const rejectedCreate = handle.connect({ + label: "rejected-create", + rejectCreate: true, + }); + await expect(rejectedCreate.snapshot()).rejects.toThrow(); + + const snapshot = await primary.snapshot(); + expect(snapshot.disconnectSnapshots).toEqual([]); + expect(snapshot.visibleLabels).toEqual(["primary"]); + + await primary.dispose(); + }); + }, + { + encodings: ["bare"], + runtimes: ["wasm"], + sqliteBackends: ["remote"], + }, +);