diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/messages.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/messages.rs index 970c517423..4ea57aedd8 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/messages.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/messages.rs @@ -285,6 +285,7 @@ pub enum ActorEvent { reply: Reply, }, WebSocketOpen { + conn: ConnHandle, ws: WebSocket, request: Option, reply: Reply<()>, diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/task.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/task.rs index 8dff4a424d..e4df478cc9 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/task.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/task.rs @@ -282,6 +282,7 @@ pub enum DispatchCommand { reply: oneshot::Sender, }, OpenWebSocket { + conn: ConnHandle, ws: WebSocket, request: Option, reply: oneshot::Sender>, @@ -1031,10 +1032,16 @@ impl ActorTask { } } } - DispatchCommand::OpenWebSocket { ws, request, reply } => { + DispatchCommand::OpenWebSocket { + conn, + ws, + request, + reply, + } => { match self.send_actor_event( "dispatch_websocket_open", ActorEvent::WebSocketOpen { + conn, ws, request, reply: Reply::from(reply), diff --git a/rivetkit-rust/packages/rivetkit-core/src/registry/dispatch.rs b/rivetkit-rust/packages/rivetkit-core/src/registry/dispatch.rs index d642e2a9c9..2cbc97cad8 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/registry/dispatch.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/registry/dispatch.rs @@ -81,6 +81,7 @@ where pub(super) async fn dispatch_websocket_open_through_task( dispatch: &mpsc::Sender, capacity: usize, + conn: ConnHandle, ws: WebSocket, request: Option, ) -> Result<()> { @@ -90,6 +91,7 @@ pub(super) async fn dispatch_websocket_open_through_task( capacity, "dispatch_websocket_open", DispatchCommand::OpenWebSocket { + conn, ws, request, reply: reply_tx, diff --git a/rivetkit-rust/packages/rivetkit-core/src/registry/websocket.rs b/rivetkit-rust/packages/rivetkit-core/src/registry/websocket.rs index 0be7916a97..365e8e710f 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/registry/websocket.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/registry/websocket.rs @@ -546,6 +546,7 @@ impl RegistryDispatcher { let dispatch_capacity = instance.factory.config().dispatch_command_inbox_capacity; let conn_for_close = conn.clone(); let conn_for_message = conn.clone(); + let conn_for_open = conn.clone(); let ctx_for_message = ctx.clone(); let ctx_for_close = ctx.clone(); let ws = WebSocket::new(); @@ -657,6 +658,7 @@ impl RegistryDispatcher { }), on_open: Some(Box::new(move |sender| { let request = request_for_open.clone(); + let conn = conn_for_open.clone(); let ws = ws_for_open.clone(); let actor_id = actor_id_for_open.clone(); let dispatch = dispatch.clone(); @@ -666,6 +668,7 @@ impl RegistryDispatcher { let result = dispatch_websocket_open_through_task( &dispatch, dispatch_capacity, + conn, ws.clone(), Some(request), ) diff --git a/rivetkit-typescript/packages/rivetkit-napi/src/actor_factory.rs b/rivetkit-typescript/packages/rivetkit-napi/src/actor_factory.rs index 8e4411cf42..2603fc37a1 100644 --- a/rivetkit-typescript/packages/rivetkit-napi/src/actor_factory.rs +++ b/rivetkit-typescript/packages/rivetkit-napi/src/actor_factory.rs @@ -141,6 +141,7 @@ pub(crate) struct QueueSendPayload { #[derive(Clone)] pub(crate) struct WebSocketPayload { pub(crate) ctx: CoreActorContext, + pub(crate) conn: CoreConnHandle, pub(crate) ws: CoreWebSocket, pub(crate) request: Option, } @@ -833,6 +834,7 @@ fn build_websocket_payload( ) -> napi::Result> { let mut object = env.create_object()?; object.set("ctx", ActorContext::new(payload.ctx))?; + object.set("conn", ConnHandle::new(payload.conn))?; object.set("ws", WebSocket::new(payload.ws))?; if let Some(request) = payload.request { object.set("request", build_request_object(env, request)?)?; diff --git a/rivetkit-typescript/packages/rivetkit-napi/src/napi_actor_events.rs b/rivetkit-typescript/packages/rivetkit-napi/src/napi_actor_events.rs index abfa4ce7b3..049f0625d9 100644 --- a/rivetkit-typescript/packages/rivetkit-napi/src/napi_actor_events.rs +++ b/rivetkit-typescript/packages/rivetkit-napi/src/napi_actor_events.rs @@ -474,14 +474,19 @@ pub(crate) async fn dispatch_event( .await }); } - ActorEvent::WebSocketOpen { ws, request, reply } => { + ActorEvent::WebSocketOpen { + conn, + ws, + request, + reply, + } => { let Some(callback) = bindings.on_websocket.clone() else { reply.send(Ok(())); return; }; let ctx = ctx.clone(); spawn_reply(tasks, abort.clone(), reply, async move { - call_on_websocket(&callback, &ctx, ws, request).await + call_on_websocket(&callback, &ctx, conn, ws, request).await }); } ActorEvent::ConnectionOpen { @@ -1118,6 +1123,7 @@ where async fn call_on_websocket( callback: &crate::actor_factory::CallbackTsfn, ctx: &ActorContext, + conn: rivetkit_core::ConnHandle, ws: rivetkit_core::WebSocket, request: Option, ) -> Result<()> { @@ -1126,6 +1132,7 @@ async fn call_on_websocket( callback, WebSocketPayload { ctx: ctx.inner().clone(), + conn, ws, request, }, diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts index 884ddd83f2..1c7254dcea 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts @@ -186,3 +186,21 @@ export const rawWebSocketAsyncOpenActor = actor({ getOpenCount: (ctx) => ctx.state.openCount, }, }); + +export const rawWebSocketConnContextActor = actor({ + onWebSocket(ctx: any, websocket: UniversalWebSocket) { + const connId = ctx.conn.id; + ctx.conn.state = { + opened: true, + connId, + }; + websocket.send( + JSON.stringify({ + type: "conn-context", + connId, + state: ctx.conn.state, + }), + ); + }, + actions: {}, +}); diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-static.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-static.ts index c58fa23abf..722eb42158 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-static.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-static.ts @@ -75,6 +75,7 @@ import { rawWebSocketActor, rawWebSocketAsyncOpenActor, rawWebSocketBinaryActor, + rawWebSocketConnContextActor, } from "./raw-websocket"; import { rejectConnectionActor } from "./reject-connection"; import { requestAccessActor } from "./request-access"; @@ -268,6 +269,7 @@ export const registry = setup({ rawWebSocketActor, rawWebSocketAsyncOpenActor, rawWebSocketBinaryActor, + rawWebSocketConnContextActor, // From reject-connection.ts rejectConnectionActor, // From request-access.ts diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/native.ts b/rivetkit-typescript/packages/rivetkit/src/registry/native.ts index ff8bef74ab..f91564108e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/native.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/native.ts @@ -4134,6 +4134,7 @@ export function buildNativeFactory( error: unknown, payload: { ctx: NativeActorContext; + conn: NativeConnHandle; ws: NativeWebSocket; request?: { method: string; @@ -4143,14 +4144,19 @@ export function buildNativeFactory( }; }, ) => { - const { ctx, ws, request } = unwrapTsfnPayload( - error, - payload, - ); + const { ctx, conn, ws, request } = + unwrapTsfnPayload( + error, + payload, + ); const jsRequest = request ? buildRequest(request) : undefined; - const actorCtx = makeActorCtx(ctx, jsRequest); + const actorCtx = makeConnCtx( + ctx, + conn, + jsRequest, + ); try { await config.onWebSocket( actorCtx, diff --git a/rivetkit-typescript/packages/rivetkit/tests/driver/raw-websocket.test.ts b/rivetkit-typescript/packages/rivetkit/tests/driver/raw-websocket.test.ts index 3b575b099a..5242f55c04 100644 --- a/rivetkit-typescript/packages/rivetkit/tests/driver/raw-websocket.test.ts +++ b/rivetkit-typescript/packages/rivetkit/tests/driver/raw-websocket.test.ts @@ -452,6 +452,25 @@ describeDriverMatrix("Raw Websocket", (driverTestConfig) => { ws.close(); }); + test("should expose connection context in onWebSocket", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + const actor = client.rawWebSocketConnContextActor.getOrCreate([ + "conn-context", + ]); + + const ws = await actor.webSocket(); + const message = await waitForJsonMessage(ws, 5_000); + + expect(message?.type).toBe("conn-context"); + expect(typeof message?.connId).toBe("string"); + expect(message?.state).toEqual({ + opened: true, + connId: message?.connId, + }); + + ws.close(); + }); + test("should properly handle onWebSocket open and close events", async (c) => { const { client } = await setupDriverTest(c, driverTestConfig); const actor = client.rawWebSocketActor.getOrCreate([