@@ -8,7 +8,7 @@ use crate::auth::{
88} ;
99use crate :: routes:: subscribe:: generate_random_connection_id;
1010use crate :: util:: { ByteStringBody , NameOrIdentity } ;
11- use crate :: { log_and_500, ControlStateDelegate , DatabaseDef , NodeDelegate } ;
11+ use crate :: { log_and_500, ControlStateDelegate , DatabaseDef , Host , NodeDelegate } ;
1212use axum:: body:: { Body , Bytes } ;
1313use axum:: extract:: { Path , Query , State } ;
1414use axum:: response:: { ErrorResponse , IntoResponse } ;
@@ -20,16 +20,16 @@ use http::StatusCode;
2020use serde:: Deserialize ;
2121use spacetimedb:: database_logger:: DatabaseLogger ;
2222use spacetimedb:: host:: module_host:: ClientConnectedError ;
23- use spacetimedb:: host:: ReducerArgs ;
2423use spacetimedb:: host:: ReducerCallError ;
2524use spacetimedb:: host:: ReducerOutcome ;
2625use spacetimedb:: host:: UpdateDatabaseResult ;
26+ use spacetimedb:: host:: { ModuleHost , ReducerArgs } ;
2727use spacetimedb:: identity:: Identity ;
2828use spacetimedb:: messages:: control_db:: { Database , HostType } ;
2929use spacetimedb_client_api_messages:: name:: { self , DatabaseName , DomainName , PublishOp , PublishResult } ;
3030use spacetimedb_lib:: db:: raw_def:: v9:: RawModuleDefV9 ;
3131use spacetimedb_lib:: identity:: AuthCtx ;
32- use spacetimedb_lib:: sats;
32+ use spacetimedb_lib:: { sats, ConnectionId } ;
3333
3434use super :: subscribe:: handle_websocket;
3535
@@ -41,22 +41,20 @@ pub struct CallParams {
4141
4242pub const NO_SUCH_DATABASE : ( StatusCode , & str ) = ( StatusCode :: NOT_FOUND , "No such database." ) ;
4343
44- pub async fn call < S : ControlStateDelegate + NodeDelegate > (
45- State ( worker_ctx) : State < S > ,
46- Extension ( auth) : Extension < SpacetimeAuth > ,
47- Path ( CallParams {
48- name_or_identity,
49- reducer,
50- } ) : Path < CallParams > ,
51- TypedHeader ( content_type) : TypedHeader < headers:: ContentType > ,
52- ByteStringBody ( body) : ByteStringBody ,
53- ) -> axum:: response:: Result < impl IntoResponse > {
54- if content_type != headers:: ContentType :: json ( ) {
55- return Err ( axum:: extract:: rejection:: MissingJsonContentType :: default ( ) . into ( ) ) ;
56- }
57- let caller_identity = auth. identity ;
44+ struct Connected {
45+ database : Database ,
46+ leader : Host ,
47+ module : ModuleHost ,
48+ connection_id : ConnectionId ,
49+ caller_identity : Identity ,
50+ }
5851
59- let args = ReducerArgs :: Json ( body) ;
52+ async fn call_on_connect < S : ControlStateDelegate + NodeDelegate > (
53+ worker_ctx : S ,
54+ auth : SpacetimeAuth ,
55+ name_or_identity : NameOrIdentity ,
56+ ) -> axum:: response:: Result < Connected > {
57+ let caller_identity = auth. identity ;
6058
6159 let db_identity = name_or_identity. resolve ( & worker_ctx) . await ?;
6260 let database = worker_ctx_find_database ( & worker_ctx, & db_identity)
@@ -65,7 +63,6 @@ pub async fn call<S: ControlStateDelegate + NodeDelegate>(
6563 log:: error!( "Could not find database: {}" , db_identity. to_hex( ) ) ;
6664 NO_SUCH_DATABASE
6765 } ) ?;
68- let identity = database. owner_identity ;
6966
7067 let leader = worker_ctx
7168 . leader ( database. id )
@@ -81,32 +78,75 @@ pub async fn call<S: ControlStateDelegate + NodeDelegate>(
8178 match module. call_identity_connected ( caller_identity, connection_id) . await {
8279 // If `call_identity_connected` returns `Err(Rejected)`, then the `client_connected` reducer errored,
8380 // meaning the connection was refused. Return 403 forbidden.
84- Err ( ClientConnectedError :: Rejected ( msg) ) => return Err ( ( StatusCode :: FORBIDDEN , msg) . into ( ) ) ,
81+ Err ( ClientConnectedError :: Rejected ( msg) ) => Err ( ( StatusCode :: FORBIDDEN , msg) . into ( ) ) ,
8582 // If `call_identity_connected` returns `Err(OutOfEnergy)`,
8683 // then, well, the database is out of energy.
8784 // Return 503 service unavailable.
88- Err ( err @ ClientConnectedError :: OutOfEnergy ) => {
89- return Err ( ( StatusCode :: SERVICE_UNAVAILABLE , err. to_string ( ) ) . into ( ) )
90- }
85+ Err ( err @ ClientConnectedError :: OutOfEnergy ) => Err ( ( StatusCode :: SERVICE_UNAVAILABLE , err. to_string ( ) ) . into ( ) ) ,
9186 // If `call_identity_connected` returns `Err(ReducerCall)`,
9287 // something went wrong while invoking the `client_connected` reducer.
9388 // I (pgoldman 2025-03-27) am not really sure how this would happen,
9489 // but we returned 404 not found in this case prior to my editing this code,
9590 // so I guess let's keep doing that.
9691 Err ( ClientConnectedError :: ReducerCall ( e) ) => {
97- return Err ( ( StatusCode :: NOT_FOUND , format ! ( "{:#}" , anyhow:: anyhow!( e) ) ) . into ( ) )
92+ Err ( ( StatusCode :: NOT_FOUND , format ! ( "{:#}" , anyhow:: anyhow!( e) ) ) . into ( ) )
9893 }
9994 // If `call_identity_connected` returns `Err(DBError)`,
10095 // then the module didn't define `client_connected`,
10196 // but something went wrong when we tried to insert into `st_client`.
10297 // That's weird and scary, so return 500 internal error.
103- Err ( e @ ClientConnectedError :: DBError ( _) ) => {
104- return Err ( ( StatusCode :: INTERNAL_SERVER_ERROR , e. to_string ( ) ) . into ( ) )
105- }
98+ Err ( e @ ClientConnectedError :: DBError ( _) ) => Err ( ( StatusCode :: INTERNAL_SERVER_ERROR , e. to_string ( ) ) . into ( ) ) ,
10699
107100 // If `call_identity_connected` returns `Ok`, then we can actually call the reducer we want.
108- Ok ( ( ) ) => ( ) ,
101+ Ok ( ( ) ) => Ok ( Connected {
102+ caller_identity,
103+ database,
104+ leader,
105+ module,
106+ connection_id,
107+ } ) ,
109108 }
109+ }
110+
111+ async fn call_on_disconnect (
112+ module : ModuleHost ,
113+ connection_id : ConnectionId ,
114+ caller_identity : Identity ,
115+ ) -> axum:: response:: Result < ( ) > {
116+ if let Err ( e) = module. call_identity_disconnected ( caller_identity, connection_id) . await {
117+ // If `call_identity_disconnected` errors, something is very wrong:
118+ // it means we tried to delete the `st_client` row but failed.
119+ // Note that `call_identity_disconnected` swallows errors from the `client_disconnected` reducer.
120+ // Slap a 500 on it and pray.
121+ return Err ( ( StatusCode :: INTERNAL_SERVER_ERROR , format ! ( "{:#}" , anyhow:: anyhow!( e) ) ) . into ( ) ) ;
122+ }
123+ Ok ( ( ) )
124+ }
125+
126+ pub async fn call < S : ControlStateDelegate + NodeDelegate > (
127+ State ( worker_ctx) : State < S > ,
128+ Extension ( auth) : Extension < SpacetimeAuth > ,
129+ Path ( CallParams {
130+ name_or_identity,
131+ reducer,
132+ } ) : Path < CallParams > ,
133+ TypedHeader ( content_type) : TypedHeader < headers:: ContentType > ,
134+ ByteStringBody ( body) : ByteStringBody ,
135+ ) -> axum:: response:: Result < impl IntoResponse > {
136+ if content_type != headers:: ContentType :: json ( ) {
137+ return Err ( axum:: extract:: rejection:: MissingJsonContentType :: default ( ) . into ( ) ) ;
138+ }
139+
140+ let Connected {
141+ caller_identity,
142+ database,
143+ leader : _,
144+ module,
145+ connection_id,
146+ } = call_on_connect ( worker_ctx, auth, name_or_identity) . await ?;
147+
148+ let args = ReducerArgs :: Json ( body) ;
149+
110150 let result = match module
111151 . call_reducer ( caller_identity, Some ( connection_id) , None , None , None , & reducer, args)
112152 . await
@@ -134,17 +174,11 @@ pub async fn call<S: ControlStateDelegate + NodeDelegate>(
134174 }
135175 } ;
136176
137- if let Err ( e) = module. call_identity_disconnected ( caller_identity, connection_id) . await {
138- // If `call_identity_disconnected` errors, something is very wrong:
139- // it means we tried to delete the `st_client` row but failed.
140- // Note that `call_identity_disconnected` swallows errors from the `client_disconnected` reducer.
141- // Slap a 500 on it and pray.
142- return Err ( ( StatusCode :: INTERNAL_SERVER_ERROR , format ! ( "{:#}" , anyhow:: anyhow!( e) ) ) . into ( ) ) ;
143- }
177+ call_on_disconnect ( module, connection_id, caller_identity) . await ?;
144178
145179 match result {
146180 Ok ( result) => {
147- let ( status, body) = reducer_outcome_response ( & identity , & reducer, result. outcome ) ;
181+ let ( status, body) = reducer_outcome_response ( & database . owner_identity , & reducer, result. outcome ) ;
148182 Ok ( (
149183 status,
150184 TypedHeader ( SpacetimeEnergyUsed ( result. energy_used ) ) ,
@@ -400,22 +434,21 @@ where
400434{
401435 // Anyone is authorized to execute SQL queries. The SQL engine will determine
402436 // which queries this identity is allowed to execute against the database.
403-
404- let db_identity = name_or_identity. resolve ( & worker_ctx) . await ?;
405- let database = worker_ctx_find_database ( & worker_ctx, & db_identity)
406- . await ?
407- . ok_or ( NO_SUCH_DATABASE ) ?;
408-
409- let auth = AuthCtx :: new ( database. owner_identity , auth. identity ) ;
437+ let Connected {
438+ caller_identity,
439+ database,
440+ leader,
441+ module,
442+ connection_id,
443+ } = call_on_connect ( worker_ctx, auth, name_or_identity) . await ?;
444+
445+ let auth = AuthCtx :: new ( database. owner_identity , caller_identity) ;
410446 log:: debug!( "auth: {auth:?}" ) ;
411447
412- let host = worker_ctx
413- . leader ( database. id )
414- . await
415- . map_err ( log_and_500) ?
416- . ok_or ( StatusCode :: NOT_FOUND ) ?;
417- let json = host. exec_sql ( auth, database, body) . await ?;
418-
448+ // Notify the disconnect even if the SQL execution fails.
449+ let json_result = leader. exec_sql ( auth, database, body) . await ;
450+ call_on_disconnect ( module, connection_id, caller_identity) . await ?;
451+ let json = json_result?;
419452 let total_duration = json. iter ( ) . fold ( 0 , |acc, x| acc + x. total_duration_micros ) ;
420453
421454 Ok ( (
@@ -766,9 +799,10 @@ where
766799 names_put : put ( set_names :: < S > ) ,
767800 identity_get : get ( get_identity :: < S > ) ,
768801 subscribe_get : get ( handle_websocket :: < S > ) ,
769- call_reducer_post : post ( call :: < S > ) ,
770802 schema_get : get ( schema :: < S > ) ,
771803 logs_get : get ( logs :: < S > ) ,
804+ // Need calls to on_connect and on_disconnect...
805+ call_reducer_post : post ( call :: < S > ) ,
772806 sql_post : post ( sql :: < S > ) ,
773807 }
774808 }
0 commit comments