|
| 1 | +use rusqlite::Connection; |
| 2 | +use std::sync::{Arc, Mutex}; |
| 3 | +use tokio::task; |
| 4 | + |
| 5 | +/// A connection manager for SQLite that wraps rusqlite's synchronous Connection |
| 6 | +/// behind an Arc<Mutex<>> for thread-safe access with bb8 connection pooling. |
| 7 | +#[derive(Clone, Debug)] |
| 8 | +pub struct SqliteConnectionManager { |
| 9 | + db_path: String, |
| 10 | + connection_name: String, |
| 11 | +} |
| 12 | + |
| 13 | +/// Wrapper around rusqlite::Connection to make it Send + Sync for bb8 |
| 14 | +pub struct SqliteConnection { |
| 15 | + pub conn: Arc<Mutex<Connection>>, |
| 16 | +} |
| 17 | + |
| 18 | +// Safety: rusqlite::Connection is not Send by default, but we protect it with Mutex |
| 19 | +// and only access it via spawn_blocking |
| 20 | +unsafe impl Send for SqliteConnection {} |
| 21 | +unsafe impl Sync for SqliteConnection {} |
| 22 | + |
| 23 | +impl SqliteConnectionManager { |
| 24 | + pub fn new(db_path: String, connection_name: String) -> Self { |
| 25 | + Self { |
| 26 | + db_path, |
| 27 | + connection_name, |
| 28 | + } |
| 29 | + } |
| 30 | +} |
| 31 | + |
| 32 | +#[derive(Debug)] |
| 33 | +pub struct SqlitePoolError(pub String); |
| 34 | + |
| 35 | +impl std::fmt::Display for SqlitePoolError { |
| 36 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 37 | + write!(f, "SQLite pool error: {}", self.0) |
| 38 | + } |
| 39 | +} |
| 40 | + |
| 41 | +impl std::error::Error for SqlitePoolError {} |
| 42 | + |
| 43 | +impl bb8::ManageConnection for SqliteConnectionManager { |
| 44 | + type Connection = SqliteConnection; |
| 45 | + type Error = SqlitePoolError; |
| 46 | + |
| 47 | + async fn connect(&self) -> Result<Self::Connection, Self::Error> { |
| 48 | + let db_path = self.db_path.clone(); |
| 49 | + let connection_name = self.connection_name.clone(); |
| 50 | + |
| 51 | + let conn = task::spawn_blocking(move || { |
| 52 | + Connection::open(&db_path).unwrap_or_else(|err| { |
| 53 | + panic!( |
| 54 | + "Failed to open SQLite database at '{}' for connection '{}': {}", |
| 55 | + db_path, connection_name, err |
| 56 | + ) |
| 57 | + }) |
| 58 | + }) |
| 59 | + .await |
| 60 | + .map_err(|e| SqlitePoolError(format!("Failed to spawn blocking task: {e}")))?; |
| 61 | + |
| 62 | + Ok(SqliteConnection { |
| 63 | + conn: Arc::new(Mutex::new(conn)), |
| 64 | + }) |
| 65 | + } |
| 66 | + |
| 67 | + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { |
| 68 | + let inner = conn.conn.clone(); |
| 69 | + task::spawn_blocking(move || { |
| 70 | + let conn = inner.lock().unwrap(); |
| 71 | + conn |
| 72 | + .execute_batch("SELECT 1") |
| 73 | + .map_err(|e| SqlitePoolError(format!("SQLite connection validation failed: {e}"))) |
| 74 | + }) |
| 75 | + .await |
| 76 | + .map_err(|e| SqlitePoolError(format!("Failed to spawn blocking task: {e}")))? |
| 77 | + } |
| 78 | + |
| 79 | + fn has_broken(&self, _conn: &mut Self::Connection) -> bool { |
| 80 | + false |
| 81 | + } |
| 82 | +} |
0 commit comments