Skip to content

Commit 3a8d16e

Browse files
committed
fix(rivetkit): prevent sqlite access after shtudown
1 parent 30bb0dd commit 3a8d16e

1 file changed

Lines changed: 30 additions & 1 deletion

File tree

  • rivetkit-rust/packages/rivetkit-core/src/actor

rivetkit-rust/packages/rivetkit-core/src/actor/sqlite.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use std::collections::HashSet;
22
use std::io::Cursor;
33
#[cfg(feature = "sqlite")]
4-
use std::sync::Arc;
4+
use std::sync::{
5+
Arc,
6+
atomic::{AtomicBool, Ordering},
7+
};
58

69
use anyhow::{Context, Result};
710
#[cfg(feature = "sqlite")]
@@ -84,6 +87,8 @@ pub struct SqliteDb {
8487
// Forced-sync: native SQLite handles are used inside spawn_blocking and
8588
// synchronous diagnostic accessors.
8689
db: Arc<Mutex<Option<NativeDatabaseHandle>>>,
90+
#[cfg(feature = "sqlite")]
91+
cleaned_up: Arc<AtomicBool>,
8792
}
8893

8994
impl SqliteDb {
@@ -98,6 +103,8 @@ impl SqliteDb {
98103
enabled,
99104
#[cfg(feature = "sqlite")]
100105
db: Default::default(),
106+
#[cfg(feature = "sqlite")]
107+
cleaned_up: Default::default(),
101108
}
102109
}
103110

@@ -122,13 +129,23 @@ impl SqliteDb {
122129
pub async fn open(&self) -> Result<()> {
123130
#[cfg(feature = "sqlite")]
124131
{
132+
self.ensure_not_cleaned_up()?;
133+
125134
let config = self.runtime_config()?;
126135
let db = self.db.clone();
136+
let cleaned_up = self.cleaned_up.clone();
127137
let rt_handle = tokio::runtime::Handle::try_current()
128138
.context("open sqlite database requires a tokio runtime")?;
129139

130140
tokio::task::spawn_blocking(move || {
141+
if cleaned_up.load(Ordering::Acquire) {
142+
return Err(SqliteRuntimeError::Closed.build());
143+
}
144+
131145
let mut guard = db.lock();
146+
if cleaned_up.load(Ordering::Acquire) {
147+
return Err(SqliteRuntimeError::Closed.build());
148+
}
132149
if guard.is_some() {
133150
return Ok(());
134151
}
@@ -251,6 +268,9 @@ impl SqliteDb {
251268
}
252269

253270
pub(crate) async fn cleanup(&self) -> Result<()> {
271+
#[cfg(feature = "sqlite")]
272+
self.cleaned_up.store(true, Ordering::Release);
273+
254274
self.close().await
255275
}
256276

@@ -319,6 +339,15 @@ impl SqliteDb {
319339
.clone()
320340
.ok_or_else(|| sqlite_not_configured("handle"))
321341
}
342+
343+
#[cfg(feature = "sqlite")]
344+
fn ensure_not_cleaned_up(&self) -> Result<()> {
345+
if self.cleaned_up.load(Ordering::Acquire) {
346+
return Err(SqliteRuntimeError::Closed.build());
347+
}
348+
349+
Ok(())
350+
}
322351
}
323352

324353
impl std::fmt::Debug for SqliteDb {

0 commit comments

Comments
 (0)