Skip to content

Commit 9bf8eb7

Browse files
committed
fix(sqlite): reject work after shutdown close
1 parent 977011c commit 9bf8eb7

1 file changed

Lines changed: 29 additions & 2 deletions

File tree

  • rivetkit-rust/packages/rivetkit-core/src/actor

rivetkit-rust/packages/rivetkit-core/src/actor/sqlite.rs

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use std::collections::HashSet;
22
use std::io::Cursor;
3-
#[cfg(feature = "sqlite-local")]
4-
use std::sync::Arc;
3+
use std::sync::{
4+
Arc,
5+
atomic::{AtomicBool, Ordering},
6+
};
57
#[cfg(feature = "sqlite-local")]
68
use std::time::Duration;
79

@@ -79,6 +81,7 @@ pub struct SqliteDb {
7981
/// always sets up sqlite storage under the hood, so handle/actor_id are
8082
/// not a reliable signal for whether the user opted in; this flag is.
8183
enabled: bool,
84+
shutdown_closed: Arc<AtomicBool>,
8285
#[cfg(feature = "sqlite-local")]
8386
// Forced-sync: native SQLite handles are used inside spawn_blocking and
8487
// synchronous diagnostic accessors.
@@ -115,6 +118,7 @@ impl SqliteDb {
115118
generation,
116119
backend: select_sqlite_backend(enabled, remote_sqlite),
117120
enabled,
121+
shutdown_closed: Default::default(),
118122
#[cfg(feature = "sqlite-local")]
119123
db: Default::default(),
120124
#[cfg(feature = "sqlite-local")]
@@ -143,22 +147,26 @@ impl SqliteDb {
143147
&self,
144148
request: protocol::SqliteGetPagesRequest,
145149
) -> Result<protocol::SqliteGetPagesResponse> {
150+
self.ensure_not_shutdown_closed()?;
146151
self.handle()?.sqlite_get_pages(request).await
147152
}
148153

149154
pub async fn commit(
150155
&self,
151156
request: protocol::SqliteCommitRequest,
152157
) -> Result<protocol::SqliteCommitResponse> {
158+
self.ensure_not_shutdown_closed()?;
153159
self.handle()?.sqlite_commit(request).await
154160
}
155161

156162
pub async fn open(&self) -> Result<()> {
163+
self.ensure_not_shutdown_closed()?;
157164
match self.backend {
158165
SqliteBackend::LocalNative => {
159166
#[cfg(feature = "sqlite-local")]
160167
{
161168
let _open_guard = self.open_lock.lock().await;
169+
self.ensure_not_shutdown_closed()?;
162170
if self.db.lock().is_some() {
163171
return Ok(());
164172
}
@@ -178,6 +186,10 @@ impl SqliteDb {
178186
vfs_metrics,
179187
)
180188
.await?;
189+
if self.shutdown_closed.load(Ordering::SeqCst) {
190+
native_db.close().await?;
191+
return Err(SqliteRuntimeError::Closed.build());
192+
}
181193
*self.db.lock() = Some(native_db);
182194
self.ensure_preload_hint_flush_task()?;
183195
Ok(())
@@ -272,6 +284,7 @@ impl SqliteDb {
272284
}
273285

274286
pub async fn exec(&self, sql: impl Into<String>) -> Result<QueryResult> {
287+
self.ensure_not_shutdown_closed()?;
275288
let sql = sql.into();
276289
match self.backend {
277290
SqliteBackend::LocalNative => self.local_exec(sql).await,
@@ -285,6 +298,7 @@ impl SqliteDb {
285298
sql: impl Into<String>,
286299
params: Option<Vec<BindParam>>,
287300
) -> Result<QueryResult> {
301+
self.ensure_not_shutdown_closed()?;
288302
let sql = sql.into();
289303
match self.backend {
290304
SqliteBackend::LocalNative => self.local_query(sql, params).await,
@@ -300,6 +314,7 @@ impl SqliteDb {
300314
sql: impl Into<String>,
301315
params: Option<Vec<BindParam>>,
302316
) -> Result<ExecResult> {
317+
self.ensure_not_shutdown_closed()?;
303318
let sql = sql.into();
304319
match self.backend {
305320
SqliteBackend::LocalNative => self.local_run(sql, params).await,
@@ -315,6 +330,7 @@ impl SqliteDb {
315330
sql: impl Into<String>,
316331
params: Option<Vec<BindParam>>,
317332
) -> Result<ExecuteResult> {
333+
self.ensure_not_shutdown_closed()?;
318334
let sql = sql.into();
319335
match self.backend {
320336
SqliteBackend::LocalNative => self.local_execute(sql, params).await,
@@ -328,6 +344,7 @@ impl SqliteDb {
328344
sql: impl Into<String>,
329345
params: Option<Vec<BindParam>>,
330346
) -> Result<ExecuteResult> {
347+
self.ensure_not_shutdown_closed()?;
331348
let sql = sql.into();
332349
match self.backend {
333350
SqliteBackend::LocalNative => self.local_execute_write(sql, params).await,
@@ -354,6 +371,9 @@ impl SqliteDb {
354371
}
355372

356373
pub(crate) async fn cleanup(&self) -> Result<()> {
374+
self.shutdown_closed.store(true, Ordering::SeqCst);
375+
#[cfg(feature = "sqlite-local")]
376+
let _open_guard = self.open_lock.lock().await;
357377
#[cfg(feature = "sqlite-local")]
358378
{
359379
self.stop_preload_hint_flush_task();
@@ -362,6 +382,13 @@ impl SqliteDb {
362382
self.close().await
363383
}
364384

385+
fn ensure_not_shutdown_closed(&self) -> Result<()> {
386+
if self.shutdown_closed.load(Ordering::SeqCst) {
387+
return Err(SqliteRuntimeError::Closed.build());
388+
}
389+
Ok(())
390+
}
391+
365392
#[cfg(feature = "sqlite-local")]
366393
fn ensure_preload_hint_flush_task(&self) -> Result<()> {
367394
if !sqlite_optimization_flags().preload_hint_flush {

0 commit comments

Comments
 (0)