diff --git a/Cargo.toml b/Cargo.toml index aed5111..df1ab78 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,10 @@ tokio = { version = "1.47.1", features = [ "rt-multi-thread" ] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } +[dev-dependencies] +ntest = "0.9" +tokio = { version = "1.47.1", features = ["test-util", "macros"] } + [build-dependencies] napi-build = "2.0.1" diff --git a/compat.js b/compat.js index 1bbe841..374a741 100644 --- a/compat.js +++ b/compat.js @@ -28,6 +28,26 @@ function convertError(err) { return err; } +function isQueryOptions(value) { + return value != null + && typeof value === "object" + && !Array.isArray(value) + && Object.prototype.hasOwnProperty.call(value, "queryTimeout"); +} + +function splitBindParameters(bindParameters) { + if (bindParameters.length === 0) { + return { params: undefined, queryOptions: undefined }; + } + if (bindParameters.length > 1 && isQueryOptions(bindParameters[bindParameters.length - 1])) { + return { + params: bindParameters.length === 2 ? bindParameters[0] : bindParameters.slice(0, -1), + queryOptions: bindParameters[bindParameters.length - 1], + }; + } + return { params: bindParameters.length === 1 ? bindParameters[0] : bindParameters, queryOptions: undefined }; +} + /** * Database represents a connection that can prepare and execute SQL statements. */ @@ -176,9 +196,9 @@ class Database { * * @param {string} sql - The SQL statement string to execute. */ - exec(sql) { + exec(sql, queryOptions) { try { - databaseExecSync(this.db, sql); + databaseExecSync(this.db, sql, queryOptions); } catch (err) { throw convertError(err); } @@ -263,7 +283,8 @@ class Statement { */ run(...bindParameters) { try { - return statementRunSync(this.stmt, ...bindParameters); + const { params, queryOptions } = splitBindParameters(bindParameters); + return statementRunSync(this.stmt, params, queryOptions); } catch (err) { throw convertError(err); } @@ -276,7 +297,8 @@ class Statement { */ get(...bindParameters) { try { - return statementGetSync(this.stmt, ...bindParameters); + const { params, queryOptions } = splitBindParameters(bindParameters); + return statementGetSync(this.stmt, params, queryOptions); } catch (err) { throw convertError(err); } @@ -289,7 +311,8 @@ class Statement { */ iterate(...bindParameters) { try { - const it = statementIterateSync(this.stmt, ...bindParameters); + const { params, queryOptions } = splitBindParameters(bindParameters); + const it = statementIterateSync(this.stmt, params, queryOptions); return { next: () => iteratorNextSync(it), [Symbol.iterator]() { diff --git a/docs/api.md b/docs/api.md index 33b4a41..46116cb 100644 --- a/docs/api.md +++ b/docs/api.md @@ -22,6 +22,7 @@ You can use the `options` parameter to specify various options. Options supporte - `syncPeriod`: synchronize the database periodically every `syncPeriod` seconds. - `authToken`: authentication token for the provider URL (optional). - `timeout`: number of milliseconds to wait on locked database before returning `SQLITE_BUSY` error +- `defaultQueryTimeout`: default maximum number of milliseconds a query is allowed to run before being interrupted with `SQLITE_INTERRUPT` error The function returns a `Database` object. @@ -97,13 +98,14 @@ const stmt = db.prepare("SELECT * FROM users"); Loads a SQLite3 extension -### exec(sql) ⇒ this +### exec(sql[, queryOptions]) ⇒ this Executes a SQL statement. | Param | Type | Description | | ------ | ------------------- | ------------------------------------ | | sql | string | The SQL statement string to execute. | +| queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). | ### interrupt() ⇒ this @@ -119,39 +121,43 @@ Closes the database connection. ## Methods -### run([...bindParameters]) ⇒ object +### run([...bindParameters][, queryOptions]) ⇒ object Executes the SQL statement and returns an info object. | Param | Type | Description | | -------------- | ----------------------------- | ------------------------------------------------ | | bindParameters | array of objects | The bind parameters for executing the statement. | +| queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). | The returned info object contains two properties: `changes` that describes the number of modified rows and `info.lastInsertRowid` that represents the `rowid` of the last inserted row. -### get([...bindParameters]) ⇒ row +### get([...bindParameters][, queryOptions]) ⇒ row Executes the SQL statement and returns the first row. | Param | Type | Description | | -------------- | ----------------------------- | ------------------------------------------------ | | bindParameters | array of objects | The bind parameters for executing the statement. | +| queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). | -### all([...bindParameters]) ⇒ array of rows +### all([...bindParameters][, queryOptions]) ⇒ array of rows Executes the SQL statement and returns an array of the resulting rows. | Param | Type | Description | | -------------- | ----------------------------- | ------------------------------------------------ | | bindParameters | array of objects | The bind parameters for executing the statement. | +| queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). | -### iterate([...bindParameters]) ⇒ iterator +### iterate([...bindParameters][, queryOptions]) ⇒ iterator Executes the SQL statement and returns an iterator to the resulting rows. | Param | Type | Description | | -------------- | ----------------------------- | ------------------------------------------------ | | bindParameters | array of objects | The bind parameters for executing the statement. | +| queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). | ### pluck([toggleState]) ⇒ this diff --git a/index.d.ts b/index.d.ts index 15b33b4..367e13e 100644 --- a/index.d.ts +++ b/index.d.ts @@ -13,6 +13,11 @@ export interface Options { encryptionCipher?: string encryptionKey?: string remoteEncryptionKey?: string + defaultQueryTimeout?: number +} +/** Per-query execution options. */ +export interface QueryOptions { + queryTimeout?: number } export declare function connect(path: string, opts?: Options | undefined | null): Promise /** Result of a database sync operation. */ @@ -27,12 +32,12 @@ export declare function databasePrepareSync(db: Database, sql: string): Statemen /** Syncs the database in blocking mode. */ export declare function databaseSyncSync(db: Database): SyncResult /** Executes SQL in blocking mode. */ -export declare function databaseExecSync(db: Database, sql: string): void +export declare function databaseExecSync(db: Database, sql: string, queryOptions?: QueryOptions | undefined | null): void /** Gets first row from statement in blocking mode. */ -export declare function statementGetSync(stmt: Statement, params?: unknown | undefined | null): unknown +export declare function statementGetSync(stmt: Statement, params?: unknown | undefined | null, queryOptions?: QueryOptions | undefined | null): unknown /** Runs a statement in blocking mode. */ -export declare function statementRunSync(stmt: Statement, params?: unknown | undefined | null): RunResult -export declare function statementIterateSync(stmt: Statement, params?: unknown | undefined | null): RowsIterator +export declare function statementRunSync(stmt: Statement, params?: unknown | undefined | null, queryOptions?: QueryOptions | undefined | null): RunResult +export declare function statementIterateSync(stmt: Statement, params?: unknown | undefined | null, queryOptions?: QueryOptions | undefined | null): RowsIterator /** SQLite `run()` result object */ export interface RunResult { changes: number @@ -116,7 +121,7 @@ export declare class Database { * * `env` - The environment. * * `sql` - The SQL statement to execute. */ - exec(sql: string): Promise + exec(sql: string, queryOptions?: QueryOptions | undefined | null): Promise /** * Syncs the database. * @@ -153,7 +158,7 @@ export declare class Statement { * * * `params` - The parameters to bind to the statement. */ - run(params?: unknown | undefined | null): RunResult + run(params?: unknown | undefined | null, queryOptions?: QueryOptions | undefined | null): object /** * Executes a SQL statement and returns the first row. * @@ -162,7 +167,7 @@ export declare class Statement { * * `env` - The environment. * * `params` - The parameters to bind to the statement. */ - get(params?: unknown | undefined | null): object + get(params?: unknown | undefined | null, queryOptions?: QueryOptions | undefined | null): object /** * Create an iterator over the rows of a statement. * @@ -171,7 +176,7 @@ export declare class Statement { * * `env` - The environment. * * `params` - The parameters to bind to the statement. */ - iterate(params?: unknown | undefined | null): object + iterate(params?: unknown | undefined | null, queryOptions?: QueryOptions | undefined | null): object raw(raw?: boolean | undefined | null): this pluck(pluck?: boolean | undefined | null): this timing(timing?: boolean | undefined | null): this diff --git a/integration-tests/tests/async.test.js b/integration-tests/tests/async.test.js index 096fde8..cf5a205 100644 --- a/integration-tests/tests/async.test.js +++ b/integration-tests/tests/async.test.js @@ -398,6 +398,56 @@ test.serial("Timeout option", async (t) => { fs.unlinkSync(path); }); +test.serial("Query timeout option interrupts long-running query", async (t) => { + const queryTimeout = 100; + const [db, errorType] = await connect(":memory:", { defaultQueryTimeout: queryTimeout }); + const stmt = await db.prepare( + "WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;" + ); + + await t.throwsAsync(async () => { + await stmt.all(); + }, { + instanceOf: errorType, + message: "interrupted", + code: "SQLITE_INTERRUPT", + }); + + db.close(); +}); + +test.serial("Query timeout option allows short-running query", async (t) => { + const [db] = await connect(":memory:", { defaultQueryTimeout: 100 }); + const stmt = await db.prepare("SELECT 1 AS value"); + t.deepEqual(await stmt.get(), { value: 1 }); + db.close(); +}); + +test.serial("Per-query timeout option interrupts long-running Statement.all()", async (t) => { + const [db, errorType] = await connect(":memory:"); + const stmt = await db.prepare( + "WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;" + ); + + await t.throwsAsync(async () => { + await stmt.all(undefined, { queryTimeout: 100 }); + }, { + instanceOf: errorType, + message: "interrupted", + code: "SQLITE_INTERRUPT", + }); + + db.close(); +}); + +test.serial("Per-query timeout option is accepted by Database.exec()", async (t) => { + const [db] = await connect(":memory:"); + await db.exec("SELECT 1", { queryTimeout: 100 }); + t.pass(); + + db.close(); +}); + test.serial("Concurrent writes over same connection", async (t) => { const db = t.context.db; await db.exec(` diff --git a/integration-tests/tests/sync.test.js b/integration-tests/tests/sync.test.js index 5208e8d..8bbf18e 100644 --- a/integration-tests/tests/sync.test.js +++ b/integration-tests/tests/sync.test.js @@ -457,6 +457,75 @@ test.serial("Timeout option", async (t) => { fs.unlinkSync(path); }); +test.serial("Query timeout option interrupts long-running query", async (t) => { + if (t.context.provider === "sqlite") { + t.assert(true); + return; + } + + const [db, errorType] = await connect(":memory:", { defaultQueryTimeout: 100 }); + const stmt = db.prepare( + "WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;" + ); + + t.throws(() => { + stmt.all(); + }, { + instanceOf: errorType, + message: "interrupted", + code: "SQLITE_INTERRUPT", + }); + + db.close(); +}); + +test.serial("Query timeout option allows short-running query", async (t) => { + if (t.context.provider === "sqlite") { + t.assert(true); + return; + } + + const [db] = await connect(":memory:", { defaultQueryTimeout: 100 }); + const stmt = db.prepare("SELECT 1 AS value"); + t.deepEqual(stmt.get(), { value: 1 }); + db.close(); +}); + +test.serial("Per-query timeout option interrupts long-running Statement.all()", async (t) => { + if (t.context.provider === "sqlite") { + t.assert(true); + return; + } + + const [db, errorType] = await connect(":memory:"); + const stmt = db.prepare( + "WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;" + ); + + t.throws(() => { + stmt.all(undefined, { queryTimeout: 100 }); + }, { + instanceOf: errorType, + message: "interrupted", + code: "SQLITE_INTERRUPT", + }); + + db.close(); +}); + +test.serial("Per-query timeout option is accepted by Database.exec()", async (t) => { + if (t.context.provider === "sqlite") { + t.assert(true); + return; + } + + const [db] = await connect(":memory:"); + db.exec("SELECT 1", { queryTimeout: 100 }); + t.pass(); + + db.close(); +}); + test.serial("Statement.reader [SELECT is true]", async (t) => { const db = t.context.db; const stmt = db.prepare("SELECT * FROM users WHERE id = ?"); diff --git a/promise.js b/promise.js index 8a79863..f364013 100644 --- a/promise.js +++ b/promise.js @@ -41,6 +41,26 @@ function convertError(err) { return err; } +function isQueryOptions(value) { + return value != null + && typeof value === "object" + && !Array.isArray(value) + && Object.prototype.hasOwnProperty.call(value, "queryTimeout"); +} + +function splitBindParameters(bindParameters) { + if (bindParameters.length === 0) { + return { params: undefined, queryOptions: undefined }; + } + if (bindParameters.length > 1 && isQueryOptions(bindParameters[bindParameters.length - 1])) { + return { + params: bindParameters.length === 2 ? bindParameters[0] : bindParameters.slice(0, -1), + queryOptions: bindParameters[bindParameters.length - 1], + }; + } + return { params: bindParameters.length === 1 ? bindParameters[0] : bindParameters, queryOptions: undefined }; +} + /** * Creates a new database connection. * @@ -217,9 +237,9 @@ class Database { * * @param {string} sql - The SQL statement string to execute. */ - async exec(sql) { + async exec(sql, queryOptions) { try { - await this.db.exec(sql); + await this.db.exec(sql, queryOptions); } catch (err) { throw convertError(err); } @@ -308,7 +328,8 @@ class Statement { */ async run(...bindParameters) { try { - return await this.stmt.run(...bindParameters); + const { params, queryOptions } = splitBindParameters(bindParameters); + return await this.stmt.run(params, queryOptions); } catch (err) { throw convertError(err); } @@ -321,7 +342,8 @@ class Statement { */ async get(...bindParameters) { try { - return await this.stmt.get(...bindParameters); + const { params, queryOptions } = splitBindParameters(bindParameters); + return await this.stmt.get(params, queryOptions); } catch (err) { throw convertError(err); } @@ -334,7 +356,8 @@ class Statement { */ async iterate(...bindParameters) { try { - const it = await this.stmt.iterate(...bindParameters); + const { params, queryOptions } = splitBindParameters(bindParameters); + const it = await this.stmt.iterate(params, queryOptions); return { next() { return it.next(); diff --git a/src/lib.rs b/src/lib.rs index 0dbdf9d..61e1581 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,7 @@ #![allow(deprecated)] mod auth; +mod query_timeout; use napi::{ bindgen_prelude::{Array, FromNapiValue, ToNapiValue}, @@ -28,6 +29,7 @@ use napi::{ }; use napi_derive::napi; use once_cell::sync::OnceCell; +use query_timeout::{QueryTimeoutManager, TimeoutGuard}; use std::{ str::FromStr, sync::{ @@ -200,6 +202,15 @@ pub struct Options { pub encryptionKey: Option, // Encryption key for remote encryption at rest. pub remoteEncryptionKey: Option, + // Default maximum time in milliseconds that a query is allowed to run. + pub defaultQueryTimeout: Option, +} + +/// Per-query execution options. +#[napi(object)] +pub struct QueryOptions { + // Maximum time in milliseconds that this query is allowed to run. + pub queryTimeout: Option, } /// Access mode. @@ -224,10 +235,15 @@ pub struct Database { default_safe_integers: AtomicBool, // Whether to use memory-only mode. memory: bool, + // Maximum time in milliseconds that a query is allowed to run. + query_timeout: Option, + // Shared timeout manager for efficient query timeout handling. + timeout_manager: Arc, } impl Drop for Database { fn drop(&mut self) { + self.timeout_manager.shutdown(); self.conn = None; self.db = None; } @@ -321,11 +337,18 @@ pub async fn connect(path: String, opts: Option) -> Result { conn.busy_timeout(Duration::from_millis(timeout as u64)) .map_err(Error::from)? } + let query_timeout = opts + .as_ref() + .and_then(|o| o.defaultQueryTimeout) + .and_then(query_timeout_duration); + let timeout_manager = Arc::new(QueryTimeoutManager::new()); Ok(Database { db: Some(db), conn: Some(Arc::new(conn)), default_safe_integers, memory, + query_timeout, + timeout_manager, }) } @@ -388,7 +411,13 @@ impl Database { pluck: false.into(), timing: false.into(), }; - Ok(Statement::new(conn, stmt, mode)) + Ok(Statement::new( + conn, + stmt, + mode, + self.query_timeout, + self.timeout_manager.clone(), + )) } /// Sets the authorizer for the database. @@ -509,7 +538,7 @@ impl Database { /// * `env` - The environment. /// * `sql` - The SQL statement to execute. #[napi] - pub async fn exec(&self, sql: String) -> Result<()> { + pub async fn exec(&self, sql: String, query_options: Option) -> Result<()> { let conn = match &self.conn { Some(conn) => conn.clone(), None => { @@ -520,6 +549,11 @@ impl Database { )); } }; + let query_timeout = match query_options.and_then(|o| o.queryTimeout) { + Some(timeout_ms) => query_timeout_duration(timeout_ms), + None => self.query_timeout, + }; + let _guard = query_timeout.map(|t| self.timeout_manager.register(&conn, t)); conn.execute_batch(&sql).await.map_err(Error::from)?; Ok(()) } @@ -566,6 +600,7 @@ impl Database { /// Closes the database connection. #[napi] pub fn close(&mut self) -> Result<()> { + self.timeout_manager.shutdown(); self.conn = None; self.db = None; Ok(()) @@ -609,9 +644,13 @@ pub fn database_sync_sync(db: &Database) -> Result { /// Executes SQL in blocking mode. #[napi] -pub fn database_exec_sync(db: &Database, sql: String) -> Result<()> { +pub fn database_exec_sync( + db: &Database, + sql: String, + query_options: Option, +) -> Result<()> { let rt = runtime()?; - rt.block_on(async move { db.exec(sql).await }) + rt.block_on(async move { db.exec(sql, query_options).await }) } fn is_remote_path(path: &str) -> bool { @@ -625,6 +664,14 @@ fn throw_database_closed_error(env: &Env) -> napi::Error { err } +fn query_timeout_duration(timeout_ms: f64) -> Option { + if timeout_ms.is_finite() && timeout_ms > 0.0 { + Some(Duration::from_millis(timeout_ms as u64)) + } else { + None + } +} + /// SQLite statement object. #[napi] pub struct Statement { @@ -636,6 +683,10 @@ pub struct Statement { column_names: Vec, // The access mode. mode: AccessMode, + // Maximum time in milliseconds that a query is allowed to run. + query_timeout: Option, + // Shared timeout manager. + timeout_manager: Arc, } #[napi] @@ -651,6 +702,8 @@ impl Statement { conn: Arc, stmt: libsql::Statement, mode: AccessMode, + query_timeout: Option, + timeout_manager: Arc, ) -> Self { let column_names: Vec = stmt .columns() @@ -663,6 +716,8 @@ impl Statement { stmt, column_names, mode, + query_timeout, + timeout_manager, } } @@ -672,15 +727,22 @@ impl Statement { /// /// * `params` - The parameters to bind to the statement. #[napi] - pub fn run(&self, env: Env, params: Option) -> Result { + pub fn run( + &self, + env: Env, + params: Option, + query_options: Option, + ) -> Result { self.stmt.reset(); let params = map_params(&self.stmt, params)?; let total_changes_before = self.conn.total_changes(); let start = std::time::Instant::now(); let stmt = self.stmt.clone(); let conn = self.conn.clone(); + let guard = self.start_timeout_guard(query_options); let future = async move { + let _guard = guard; stmt.run(params).await.map_err(Error::from)?; let changes = if conn.total_changes() == total_changes_before { 0 @@ -706,7 +768,12 @@ impl Statement { /// * `env` - The environment. /// * `params` - The parameters to bind to the statement. #[napi] - pub fn get(&self, env: Env, params: Option) -> Result { + pub fn get( + &self, + env: Env, + params: Option, + query_options: Option, + ) -> Result { let safe_ints = self.mode.safe_ints.load(Ordering::SeqCst); let raw = self.mode.raw.load(Ordering::SeqCst); let pluck = self.mode.pluck.load(Ordering::SeqCst); @@ -723,7 +790,9 @@ impl Statement { }; let stmt_fut = stmt.clone(); + let guard = self.start_timeout_guard(query_options); let future = async move { + let _guard = guard; let mut rows = stmt_fut.query(params).await.map_err(Error::from)?; let row = rows.next().await.map_err(Error::from)?; let duration: Option = start.map(|start| start.elapsed().as_secs_f64()); @@ -779,7 +848,12 @@ impl Statement { /// * `env` - The environment. /// * `params` - The parameters to bind to the statement. #[napi] - pub fn iterate(&self, env: Env, params: Option) -> Result { + pub fn iterate( + &self, + env: Env, + params: Option, + query_options: Option, + ) -> Result { let safe_ints = self.mode.safe_ints.load(Ordering::SeqCst); let raw = self.mode.raw.load(Ordering::SeqCst); let pluck = self.mode.pluck.load(Ordering::SeqCst); @@ -787,6 +861,7 @@ impl Statement { stmt.reset(); let params = map_params(&stmt, params).unwrap(); let stmt = self.stmt.clone(); + let guard = self.start_timeout_guard(query_options); let future = async move { let rows = stmt.query(params).await.map_err(Error::from)?; Ok::<_, napi::Error>(rows) @@ -799,6 +874,7 @@ impl Statement { safe_ints, raw, pluck, + guard, )) }) } @@ -882,12 +958,27 @@ impl Statement { } } +impl Statement { + fn resolve_query_timeout(&self, query_options: Option) -> Option { + match query_options.and_then(|o| o.queryTimeout) { + Some(timeout_ms) => query_timeout_duration(timeout_ms), + None => self.query_timeout, + } + } + + fn start_timeout_guard(&self, query_options: Option) -> Option { + self.resolve_query_timeout(query_options) + .map(|t| self.timeout_manager.register(&self.conn, t)) + } +} + /// Gets first row from statement in blocking mode. #[napi] pub fn statement_get_sync( stmt: &Statement, env: Env, params: Option, + query_options: Option, ) -> Result { let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst); let raw = stmt.mode.raw.load(Ordering::SeqCst); @@ -901,6 +992,7 @@ pub fn statement_get_sync( }; let rt = runtime()?; + let _guard = stmt.start_timeout_guard(query_options); rt.block_on(async move { let params = map_params(&stmt.stmt, params)?; let mut rows = stmt.stmt.query(params).await.map_err(Error::from)?; @@ -922,9 +1014,14 @@ pub fn statement_get_sync( /// Runs a statement in blocking mode. #[napi] -pub fn statement_run_sync(stmt: &Statement, params: Option) -> Result { +pub fn statement_run_sync( + stmt: &Statement, + params: Option, + query_options: Option, +) -> Result { stmt.stmt.reset(); let rt = runtime()?; + let _guard = stmt.start_timeout_guard(query_options); rt.block_on(async move { let params = map_params(&stmt.stmt, params)?; let total_changes_before = stmt.conn.total_changes(); @@ -951,16 +1048,18 @@ pub fn statement_iterate_sync( stmt: &Statement, _env: Env, params: Option, + query_options: Option, ) -> Result { let rt = runtime()?; let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst); let raw = stmt.mode.raw.load(Ordering::SeqCst); let pluck = stmt.mode.pluck.load(Ordering::SeqCst); - let stmt = stmt.stmt.clone(); + let guard = stmt.start_timeout_guard(query_options); + let inner_stmt = stmt.stmt.clone(); let (rows, column_names) = rt.block_on(async move { - stmt.reset(); - let params = map_params(&stmt, params)?; - let rows = stmt.query(params).await.map_err(Error::from)?; + inner_stmt.reset(); + let params = map_params(&inner_stmt, params)?; + let rows = inner_stmt.query(params).await.map_err(Error::from)?; let mut column_names = Vec::new(); for i in 0..rows.column_count() { column_names @@ -974,6 +1073,7 @@ pub fn statement_iterate_sync( safe_ints, raw, pluck, + guard, )) } @@ -1120,6 +1220,7 @@ pub struct RowsIterator { safe_ints: bool, raw: bool, pluck: bool, + _timeout_guard: Option, } #[napi] @@ -1130,6 +1231,7 @@ impl RowsIterator { safe_ints: bool, raw: bool, pluck: bool, + timeout_guard: Option, ) -> Self { Self { rows, @@ -1137,6 +1239,7 @@ impl RowsIterator { safe_ints, raw, pluck, + _timeout_guard: timeout_guard, } } diff --git a/src/query_timeout.rs b/src/query_timeout.rs new file mode 100644 index 0000000..7ed504a --- /dev/null +++ b/src/query_timeout.rs @@ -0,0 +1,347 @@ +use std::{ + cmp::Reverse, + collections::BinaryHeap, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Mutex, Weak, + }, + time::Duration, +}; +use tokio::{sync::Notify, time::Instant}; + +/// A single-background-task timer wheel that interrupts connections when their +/// query deadline expires. Registering a query returns a [`TimeoutGuard`] — +/// dropping the guard cancels the timeout. +pub struct QueryTimeoutManager { + inner: Arc, +} + +struct Inner { + entries: Mutex, + /// Wakes the background task when the earliest deadline changes. + notify: Arc, + /// Set to `true` to make the background task exit. + shutdown: AtomicBool, +} + +struct Entries { + heap: BinaryHeap>, + next_id: u64, +} + +#[derive(Clone)] +struct Entry { + id: u64, + deadline: Instant, + conn: Weak, + /// Cleared when the guard is dropped (query finished in time). + active: Arc, +} + +impl PartialEq for Entry { + fn eq(&self, other: &Self) -> bool { + self.deadline == other.deadline && self.id == other.id + } +} +impl Eq for Entry {} +impl PartialOrd for Entry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for Entry { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.deadline + .cmp(&other.deadline) + .then(self.id.cmp(&other.id)) + } +} + +impl QueryTimeoutManager { + pub fn new() -> Self { + let inner = Arc::new(Inner { + entries: Mutex::new(Entries { + heap: BinaryHeap::new(), + next_id: 0, + }), + notify: Arc::new(Notify::new()), + shutdown: AtomicBool::new(false), + }); + let bg = Arc::downgrade(&inner); + tokio::spawn(async move { + Self::background_task(bg).await; + }); + Self { inner } + } + + /// Synchronously remove all entries from the heap, releasing any + /// connection references the background task is holding. + pub fn clear(&self) { + let mut entries = self.inner.entries.lock().unwrap(); + entries.heap.clear(); + } + + /// Signal the background task to exit and clear all entries. + pub fn shutdown(&self) { + self.inner.shutdown.store(true, Ordering::Relaxed); + self.clear(); + self.inner.notify.notify_one(); + } + + /// Register a query. The returned guard must be held for the duration of + /// the query — dropping it cancels the timeout. + pub fn register(&self, conn: &Arc, timeout: Duration) -> TimeoutGuard { + let active = Arc::new(AtomicBool::new(true)); + let mut entries = self.inner.entries.lock().unwrap(); + let id = entries.next_id; + entries.next_id += 1; + let deadline = Instant::now() + .checked_add(timeout) + .unwrap_or_else(|| Instant::now() + Duration::from_secs(86400)); + let entry = Entry { + id, + deadline, + conn: Arc::downgrade(conn), + active: active.clone(), + }; + let is_new_earliest = entries + .heap + .peek() + .map_or(true, |Reverse(e)| entry.deadline < e.deadline); + entries.heap.push(Reverse(entry)); + drop(entries); + if is_new_earliest { + self.inner.notify.notify_one(); + } + TimeoutGuard { + active, + notify: self.inner.notify.clone(), + } + } + + async fn background_task(weak: Weak) { + loop { + let inner = match weak.upgrade() { + Some(inner) => inner, + None => return, // Manager dropped — exit. + }; + + if inner.shutdown.load(Ordering::Relaxed) { + return; + } + + // Find the next deadline, skipping cancelled entries. + let next = { + let mut entries = inner.entries.lock().unwrap(); + loop { + match entries.heap.peek() { + Some(Reverse(e)) if !e.active.load(Ordering::Relaxed) => { + entries.heap.pop(); + } + Some(Reverse(e)) => break Some(e.clone()), + None => break None, + } + } + }; + + match next { + Some(entry) => { + tokio::select! { + _ = tokio::time::sleep_until(entry.deadline) => { + // Deadline reached — interrupt if still active. + if entry.active.load(Ordering::Relaxed) { + if let Some(conn) = entry.conn.upgrade() { + let _ = conn.interrupt(); + } + } + // Remove this entry. + let mut entries = inner.entries.lock().unwrap(); + // Pop entries that are done (expired or cancelled). + while let Some(Reverse(e)) = entries.heap.peek() { + if !e.active.load(Ordering::Relaxed) || e.id == entry.id { + entries.heap.pop(); + } else { + break; + } + } + } + _ = inner.notify.notified() => { + // A new earlier deadline was added; re-check. + } + } + } + None => { + // Nothing to do — wait until a new entry is registered. + // Must hold the Arc while waiting so we can detect drop next iteration. + inner.notify.notified().await; + } + } + } + } +} + +impl Drop for QueryTimeoutManager { + fn drop(&mut self) { + // Signal the background task to exit. + self.inner.shutdown.store(true, Ordering::Relaxed); + self.inner.notify.notify_one(); + } +} + +/// Dropping this guard cancels the associated query timeout. +pub struct TimeoutGuard { + active: Arc, + notify: Arc, +} + +impl Drop for TimeoutGuard { + fn drop(&mut self) { + self.active.store(false, Ordering::Relaxed); + // Wake the background task so it can clean up the cancelled entry. + self.notify.notify_one(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + async fn test_conn() -> Arc { + let db = libsql::Builder::new_local(":memory:") + .build() + .await + .unwrap(); + Arc::new(db.connect().unwrap()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[ntest::timeout(10000)] + async fn deadline_expires_interrupts_connection() { + let conn = test_conn().await; + let mgr = QueryTimeoutManager::new(); + + // Register a 200ms timeout, then start an infinite query. + let _guard = mgr.register(&conn, Duration::from_millis(200)); + let fut = { + let conn = conn.clone(); + tokio::spawn(async move { + conn.execute_batch( + "WITH RECURSIVE r(n) AS (SELECT 1 UNION ALL SELECT n+1 FROM r) SELECT * FROM r", + ) + .await + }) + }; + + // The query should have been interrupted by the timeout. + let result = fut.await.unwrap(); + assert!(result.is_err(), "query should have been interrupted"); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[ntest::timeout(10000)] + async fn guard_dropped_before_deadline_cancels_timeout() { + let conn = test_conn().await; + let mgr = QueryTimeoutManager::new(); + + let guard = mgr.register(&conn, Duration::from_millis(200)); + + // Query "finishes" before the deadline. + drop(guard); + + // Wait past where the deadline would have been. + tokio::time::sleep(Duration::from_millis(300)).await; + + // Connection should still work — no spurious interrupt. + let result = conn.execute_batch("SELECT 1").await; + assert!( + result.is_ok(), + "connection should not have been interrupted" + ); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[ntest::timeout(10000)] + async fn guard_drop_cleans_entry_from_heap() { + let conn = test_conn().await; + let mgr = QueryTimeoutManager::new(); + + let guard = mgr.register(&conn, Duration::from_millis(500)); + + drop(guard); + // Let the background task wake up and clean the cancelled entry. + tokio::time::sleep(Duration::from_millis(50)).await; + + let entries = mgr.inner.entries.lock().unwrap(); + assert_eq!( + entries.heap.len(), + 0, + "dropping guard should clean up the entry from the heap" + ); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[ntest::timeout(10000)] + async fn multiple_deadlines_fire_in_order() { + let conn = test_conn().await; + let mgr = QueryTimeoutManager::new(); + + // Register three timeouts at 100ms, 200ms, 300ms. + let _g1 = mgr.register(&conn, Duration::from_millis(100)); + let _g2 = mgr.register(&conn, Duration::from_millis(200)); + let _g3 = mgr.register(&conn, Duration::from_millis(300)); + + // After 150ms, only the first should have fired. + tokio::time::sleep(Duration::from_millis(150)).await; + { + let entries = mgr.inner.entries.lock().unwrap(); + assert_eq!(entries.heap.len(), 2, "only first entry should have fired"); + } + + // After 350ms total, all three should have fired. + tokio::time::sleep(Duration::from_millis(200)).await; + { + let entries = mgr.inner.entries.lock().unwrap(); + assert_eq!(entries.heap.len(), 0, "all entries should be cleaned up"); + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[ntest::timeout(10000)] + async fn new_earlier_deadline_preempts_existing() { + let conn = test_conn().await; + let mgr = QueryTimeoutManager::new(); + + // Register a long timeout first, then a shorter one that should preempt. + let _g1 = mgr.register(&conn, Duration::from_millis(5000)); + let _g2 = mgr.register(&conn, Duration::from_millis(200)); + + let fut = { + let conn = conn.clone(); + tokio::spawn(async move { + conn.execute_batch( + "WITH RECURSIVE r(n) AS (SELECT 1 UNION ALL SELECT n+1 FROM r) SELECT * FROM r", + ) + .await + }) + }; + + let result = fut.await.unwrap(); + assert!( + result.is_err(), + "shorter deadline should have interrupted the query" + ); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[ntest::timeout(10000)] + async fn background_task_exits_when_manager_dropped() { + let conn = test_conn().await; + let mgr = QueryTimeoutManager::new(); + let guard = mgr.register(&conn, Duration::from_millis(5000)); + drop(guard); + drop(mgr); + // If the background task didn't exit, it would leak — verify no panic. + tokio::time::sleep(Duration::from_millis(50)).await; + } +}