Skip to content

Commit 09bf7a7

Browse files
committed
fix(sqlite): keep shutdown database closed
1 parent 9b6024c commit 09bf7a7

1 file changed

Lines changed: 22 additions & 2 deletions

File tree

  • rivetkit-rust/packages/rivetkit-core/src/actor

rivetkit-rust/packages/rivetkit-core/src/actor/sqlite.rs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use std::collections::HashSet;
22
use std::io::Cursor;
3-
#[cfg(feature = "sqlite-local")]
4-
use std::sync::Arc;
3+
use std::sync::{
4+
Arc,
5+
atomic::{AtomicBool, Ordering},
6+
};
57
#[cfg(feature = "sqlite-local")]
68
use std::time::Duration;
79

@@ -79,6 +81,7 @@ pub struct SqliteDb {
7981
/// always sets up sqlite storage under the hood, so handle/actor_id are
8082
/// not a reliable signal for whether the user opted in; this flag is.
8183
enabled: bool,
84+
shutdown_closed: Arc<AtomicBool>,
8285
#[cfg(feature = "sqlite-local")]
8386
// Forced-sync: native SQLite handles are used inside spawn_blocking and
8487
// synchronous diagnostic accessors.
@@ -115,6 +118,7 @@ impl SqliteDb {
115118
generation,
116119
backend: select_sqlite_backend(enabled, remote_sqlite),
117120
enabled,
121+
shutdown_closed: Default::default(),
118122
#[cfg(feature = "sqlite-local")]
119123
db: Default::default(),
120124
#[cfg(feature = "sqlite-local")]
@@ -143,17 +147,20 @@ impl SqliteDb {
143147
&self,
144148
request: protocol::SqliteGetPagesRequest,
145149
) -> Result<protocol::SqliteGetPagesResponse> {
150+
self.ensure_not_shutdown_closed()?;
146151
self.handle()?.sqlite_get_pages(request).await
147152
}
148153

149154
pub async fn commit(
150155
&self,
151156
request: protocol::SqliteCommitRequest,
152157
) -> Result<protocol::SqliteCommitResponse> {
158+
self.ensure_not_shutdown_closed()?;
153159
self.handle()?.sqlite_commit(request).await
154160
}
155161

156162
pub async fn open(&self) -> Result<()> {
163+
self.ensure_not_shutdown_closed()?;
157164
match self.backend {
158165
SqliteBackend::LocalNative => {
159166
#[cfg(feature = "sqlite-local")]
@@ -272,6 +279,7 @@ impl SqliteDb {
272279
}
273280

274281
pub async fn exec(&self, sql: impl Into<String>) -> Result<QueryResult> {
282+
self.ensure_not_shutdown_closed()?;
275283
let sql = sql.into();
276284
match self.backend {
277285
SqliteBackend::LocalNative => self.local_exec(sql).await,
@@ -285,6 +293,7 @@ impl SqliteDb {
285293
sql: impl Into<String>,
286294
params: Option<Vec<BindParam>>,
287295
) -> Result<QueryResult> {
296+
self.ensure_not_shutdown_closed()?;
288297
let sql = sql.into();
289298
match self.backend {
290299
SqliteBackend::LocalNative => self.local_query(sql, params).await,
@@ -300,6 +309,7 @@ impl SqliteDb {
300309
sql: impl Into<String>,
301310
params: Option<Vec<BindParam>>,
302311
) -> Result<ExecResult> {
312+
self.ensure_not_shutdown_closed()?;
303313
let sql = sql.into();
304314
match self.backend {
305315
SqliteBackend::LocalNative => self.local_run(sql, params).await,
@@ -315,6 +325,7 @@ impl SqliteDb {
315325
sql: impl Into<String>,
316326
params: Option<Vec<BindParam>>,
317327
) -> Result<ExecuteResult> {
328+
self.ensure_not_shutdown_closed()?;
318329
let sql = sql.into();
319330
match self.backend {
320331
SqliteBackend::LocalNative => self.local_execute(sql, params).await,
@@ -328,6 +339,7 @@ impl SqliteDb {
328339
sql: impl Into<String>,
329340
params: Option<Vec<BindParam>>,
330341
) -> Result<ExecuteResult> {
342+
self.ensure_not_shutdown_closed()?;
331343
let sql = sql.into();
332344
match self.backend {
333345
SqliteBackend::LocalNative => self.local_execute_write(sql, params).await,
@@ -354,6 +366,7 @@ impl SqliteDb {
354366
}
355367

356368
pub(crate) async fn cleanup(&self) -> Result<()> {
369+
self.shutdown_closed.store(true, Ordering::SeqCst);
357370
#[cfg(feature = "sqlite-local")]
358371
{
359372
self.stop_preload_hint_flush_task();
@@ -362,6 +375,13 @@ impl SqliteDb {
362375
self.close().await
363376
}
364377

378+
fn ensure_not_shutdown_closed(&self) -> Result<()> {
379+
if self.shutdown_closed.load(Ordering::SeqCst) {
380+
return Err(SqliteRuntimeError::Closed.build());
381+
}
382+
Ok(())
383+
}
384+
365385
#[cfg(feature = "sqlite-local")]
366386
fn ensure_preload_hint_flush_task(&self) -> Result<()> {
367387
if !sqlite_optimization_flags().preload_hint_flush {

0 commit comments

Comments
 (0)