Skip to content

Commit 429e1fc

Browse files
committed
Invoke __identity_connected__ for sql http requests
1 parent e9a9de7 commit 429e1fc

3 files changed

Lines changed: 99 additions & 55 deletions

File tree

crates/client-api/src/routes/database.rs

Lines changed: 85 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::auth::{
88
};
99
use crate::routes::subscribe::generate_random_connection_id;
1010
use crate::util::{ByteStringBody, NameOrIdentity};
11-
use crate::{log_and_500, ControlStateDelegate, DatabaseDef, NodeDelegate};
11+
use crate::{log_and_500, ControlStateDelegate, DatabaseDef, Host, NodeDelegate};
1212
use axum::body::{Body, Bytes};
1313
use axum::extract::{Path, Query, State};
1414
use axum::response::{ErrorResponse, IntoResponse};
@@ -20,16 +20,16 @@ use http::StatusCode;
2020
use serde::Deserialize;
2121
use spacetimedb::database_logger::DatabaseLogger;
2222
use spacetimedb::host::module_host::ClientConnectedError;
23-
use spacetimedb::host::ReducerArgs;
2423
use spacetimedb::host::ReducerCallError;
2524
use spacetimedb::host::ReducerOutcome;
2625
use spacetimedb::host::UpdateDatabaseResult;
26+
use spacetimedb::host::{ModuleHost, ReducerArgs};
2727
use spacetimedb::identity::Identity;
2828
use spacetimedb::messages::control_db::{Database, HostType};
2929
use spacetimedb_client_api_messages::name::{self, DatabaseName, DomainName, PublishOp, PublishResult};
3030
use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
3131
use spacetimedb_lib::identity::AuthCtx;
32-
use spacetimedb_lib::sats;
32+
use spacetimedb_lib::{sats, ConnectionId};
3333

3434
use super::subscribe::handle_websocket;
3535

@@ -41,22 +41,20 @@ pub struct CallParams {
4141

4242
pub const NO_SUCH_DATABASE: (StatusCode, &str) = (StatusCode::NOT_FOUND, "No such database.");
4343

44-
pub async fn call<S: ControlStateDelegate + NodeDelegate>(
45-
State(worker_ctx): State<S>,
46-
Extension(auth): Extension<SpacetimeAuth>,
47-
Path(CallParams {
48-
name_or_identity,
49-
reducer,
50-
}): Path<CallParams>,
51-
TypedHeader(content_type): TypedHeader<headers::ContentType>,
52-
ByteStringBody(body): ByteStringBody,
53-
) -> axum::response::Result<impl IntoResponse> {
54-
if content_type != headers::ContentType::json() {
55-
return Err(axum::extract::rejection::MissingJsonContentType::default().into());
56-
}
57-
let caller_identity = auth.identity;
44+
struct Connected {
45+
database: Database,
46+
leader: Host,
47+
module: ModuleHost,
48+
connection_id: ConnectionId,
49+
caller_identity: Identity,
50+
}
5851

59-
let args = ReducerArgs::Json(body);
52+
async fn call_on_connect<S: ControlStateDelegate + NodeDelegate>(
53+
worker_ctx: S,
54+
auth: SpacetimeAuth,
55+
name_or_identity: NameOrIdentity,
56+
) -> axum::response::Result<Connected> {
57+
let caller_identity = auth.identity;
6058

6159
let db_identity = name_or_identity.resolve(&worker_ctx).await?;
6260
let database = worker_ctx_find_database(&worker_ctx, &db_identity)
@@ -65,7 +63,6 @@ pub async fn call<S: ControlStateDelegate + NodeDelegate>(
6563
log::error!("Could not find database: {}", db_identity.to_hex());
6664
NO_SUCH_DATABASE
6765
})?;
68-
let identity = database.owner_identity;
6966

7067
let leader = worker_ctx
7168
.leader(database.id)
@@ -81,32 +78,75 @@ pub async fn call<S: ControlStateDelegate + NodeDelegate>(
8178
match module.call_identity_connected(caller_identity, connection_id).await {
8279
// If `call_identity_connected` returns `Err(Rejected)`, then the `client_connected` reducer errored,
8380
// meaning the connection was refused. Return 403 forbidden.
84-
Err(ClientConnectedError::Rejected(msg)) => return Err((StatusCode::FORBIDDEN, msg).into()),
81+
Err(ClientConnectedError::Rejected(msg)) => Err((StatusCode::FORBIDDEN, msg).into()),
8582
// If `call_identity_connected` returns `Err(OutOfEnergy)`,
8683
// then, well, the database is out of energy.
8784
// Return 503 service unavailable.
88-
Err(err @ ClientConnectedError::OutOfEnergy) => {
89-
return Err((StatusCode::SERVICE_UNAVAILABLE, err.to_string()).into())
90-
}
85+
Err(err @ ClientConnectedError::OutOfEnergy) => Err((StatusCode::SERVICE_UNAVAILABLE, err.to_string()).into()),
9186
// If `call_identity_connected` returns `Err(ReducerCall)`,
9287
// something went wrong while invoking the `client_connected` reducer.
9388
// I (pgoldman 2025-03-27) am not really sure how this would happen,
9489
// but we returned 404 not found in this case prior to my editing this code,
9590
// so I guess let's keep doing that.
9691
Err(ClientConnectedError::ReducerCall(e)) => {
97-
return Err((StatusCode::NOT_FOUND, format!("{:#}", anyhow::anyhow!(e))).into())
92+
Err((StatusCode::NOT_FOUND, format!("{:#}", anyhow::anyhow!(e))).into())
9893
}
9994
// If `call_identity_connected` returns `Err(DBError)`,
10095
// then the module didn't define `client_connected`,
10196
// but something went wrong when we tried to insert into `st_client`.
10297
// That's weird and scary, so return 500 internal error.
103-
Err(e @ ClientConnectedError::DBError(_)) => {
104-
return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into())
105-
}
98+
Err(e @ ClientConnectedError::DBError(_)) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into()),
10699

107100
// If `call_identity_connected` returns `Ok`, then we can actually call the reducer we want.
108-
Ok(()) => (),
101+
Ok(()) => Ok(Connected {
102+
caller_identity,
103+
database,
104+
leader,
105+
module,
106+
connection_id,
107+
}),
109108
}
109+
}
110+
111+
async fn call_on_disconnect(
112+
module: ModuleHost,
113+
connection_id: ConnectionId,
114+
caller_identity: Identity,
115+
) -> axum::response::Result<()> {
116+
if let Err(e) = module.call_identity_disconnected(caller_identity, connection_id).await {
117+
// If `call_identity_disconnected` errors, something is very wrong:
118+
// it means we tried to delete the `st_client` row but failed.
119+
// Note that `call_identity_disconnected` swallows errors from the `client_disconnected` reducer.
120+
// Slap a 500 on it and pray.
121+
return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("{:#}", anyhow::anyhow!(e))).into());
122+
}
123+
Ok(())
124+
}
125+
126+
pub async fn call<S: ControlStateDelegate + NodeDelegate>(
127+
State(worker_ctx): State<S>,
128+
Extension(auth): Extension<SpacetimeAuth>,
129+
Path(CallParams {
130+
name_or_identity,
131+
reducer,
132+
}): Path<CallParams>,
133+
TypedHeader(content_type): TypedHeader<headers::ContentType>,
134+
ByteStringBody(body): ByteStringBody,
135+
) -> axum::response::Result<impl IntoResponse> {
136+
if content_type != headers::ContentType::json() {
137+
return Err(axum::extract::rejection::MissingJsonContentType::default().into());
138+
}
139+
140+
let Connected {
141+
caller_identity,
142+
database,
143+
leader: _,
144+
module,
145+
connection_id,
146+
} = call_on_connect(worker_ctx, auth, name_or_identity).await?;
147+
148+
let args = ReducerArgs::Json(body);
149+
110150
let result = match module
111151
.call_reducer(caller_identity, Some(connection_id), None, None, None, &reducer, args)
112152
.await
@@ -134,17 +174,11 @@ pub async fn call<S: ControlStateDelegate + NodeDelegate>(
134174
}
135175
};
136176

137-
if let Err(e) = module.call_identity_disconnected(caller_identity, connection_id).await {
138-
// If `call_identity_disconnected` errors, something is very wrong:
139-
// it means we tried to delete the `st_client` row but failed.
140-
// Note that `call_identity_disconnected` swallows errors from the `client_disconnected` reducer.
141-
// Slap a 500 on it and pray.
142-
return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("{:#}", anyhow::anyhow!(e))).into());
143-
}
177+
call_on_disconnect(module, connection_id, caller_identity).await?;
144178

145179
match result {
146180
Ok(result) => {
147-
let (status, body) = reducer_outcome_response(&identity, &reducer, result.outcome);
181+
let (status, body) = reducer_outcome_response(&database.owner_identity, &reducer, result.outcome);
148182
Ok((
149183
status,
150184
TypedHeader(SpacetimeEnergyUsed(result.energy_used)),
@@ -400,22 +434,21 @@ where
400434
{
401435
// Anyone is authorized to execute SQL queries. The SQL engine will determine
402436
// which queries this identity is allowed to execute against the database.
403-
404-
let db_identity = name_or_identity.resolve(&worker_ctx).await?;
405-
let database = worker_ctx_find_database(&worker_ctx, &db_identity)
406-
.await?
407-
.ok_or(NO_SUCH_DATABASE)?;
408-
409-
let auth = AuthCtx::new(database.owner_identity, auth.identity);
437+
let Connected {
438+
caller_identity,
439+
database,
440+
leader,
441+
module,
442+
connection_id,
443+
} = call_on_connect(worker_ctx, auth, name_or_identity).await?;
444+
445+
let auth = AuthCtx::new(database.owner_identity, caller_identity);
410446
log::debug!("auth: {auth:?}");
411447

412-
let host = worker_ctx
413-
.leader(database.id)
414-
.await
415-
.map_err(log_and_500)?
416-
.ok_or(StatusCode::NOT_FOUND)?;
417-
let json = host.exec_sql(auth, database, body).await?;
418-
448+
// Notify the disconnect even if the SQL execution fails.
449+
let json_result = leader.exec_sql(auth, database, body).await;
450+
call_on_disconnect(module, connection_id, caller_identity).await?;
451+
let json = json_result?;
419452
let total_duration = json.iter().fold(0, |acc, x| acc + x.total_duration_micros);
420453

421454
Ok((
@@ -766,9 +799,10 @@ where
766799
names_put: put(set_names::<S>),
767800
identity_get: get(get_identity::<S>),
768801
subscribe_get: get(handle_websocket::<S>),
769-
call_reducer_post: post(call::<S>),
770802
schema_get: get(schema::<S>),
771803
logs_get: get(logs::<S>),
804+
// Need calls to on_connect and on_disconnect...
805+
call_reducer_post: post(call::<S>),
772806
sql_post: post(sql::<S>),
773807
}
774808
}

smoketests/tests/client_connected_error_rejects_connection.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
}
1818
"""
1919

20+
2021
class ClientConnectedErrorRejectsConnection(Smoketest):
2122
MODULE_CODE = MODULE_HEADER + """
2223
@@ -33,12 +34,13 @@ class ClientConnectedErrorRejectsConnection(Smoketest):
3334

3435
def test_client_connected_error_rejects_connection(self):
3536
with self.assertRaises(Exception):
36-
self.subscribe("select * from all_u8s", n = 0)()
37+
self.subscribe("select * from all_u8s", n=0)()
3738

3839
logs = self.logs(100)
3940
self.assertIn('Rejecting connection from client', logs)
4041
self.assertNotIn('This should never be called, since we reject all connections!', logs)
4142

43+
4244
class ClientDisconnectedErrorStillDeletesStClient(Smoketest):
4345
MODULE_CODE = MODULE_HEADER + """
4446
#[spacetimedb::reducer(client_connected)]
@@ -53,13 +55,12 @@ class ClientDisconnectedErrorStillDeletesStClient(Smoketest):
5355
"""
5456

5557
def test_client_disconnected_error_still_deletes_st_client(self):
56-
self.subscribe("select * from all_u8s", n = 0)()
58+
self.subscribe("select * from all_u8s", n=0)()
5759

5860
logs = self.logs(100)
5961
self.assertIn('This should be called, but the `st_client` row should still be deleted', logs)
6062

6163
sql_out = self.spacetime("sql", self.database_identity, "select * from st_client")
62-
6364
self.assertMultiLineEqual(sql_out, """ identity | connection_id
6465
----------+---------------
6566
""")

smoketests/tests/connect_disconnect_from_cli.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class ConnDisconnFromCli(Smoketest):
2121
}
2222
"""
2323

24-
def test_conn_disconn(self):
24+
def test_conn_disconn_cli(self):
2525
"""
2626
Ensure that the connect and disconnect functions are called when invoking a reducer from the CLI
2727
"""
@@ -31,3 +31,12 @@ def test_conn_disconn(self):
3131
self.assertIn('_connect called', logs)
3232
self.assertIn('disconnect called', logs)
3333
self.assertIn('Hello, World!', logs)
34+
35+
def test_conn_disconn_sql(self):
36+
"""
37+
Ensure that the connect and disconnect functions are called when invoking a sql from the CLI
38+
"""
39+
self.spacetime("sql", self.database_identity, "select * from st_client")
40+
logs = self.logs(10)
41+
self.assertIn('_connect called', logs)
42+
self.assertIn('disconnect called', logs)

0 commit comments

Comments
 (0)