diff --git a/Cargo.toml b/Cargo.toml
index aed5111..df1ab78 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,6 +19,10 @@ tokio = { version = "1.47.1", features = [ "rt-multi-thread" ] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
+[dev-dependencies]
+ntest = "0.9"
+tokio = { version = "1.47.1", features = ["test-util", "macros"] }
+
[build-dependencies]
napi-build = "2.0.1"
diff --git a/compat.js b/compat.js
index 1bbe841..374a741 100644
--- a/compat.js
+++ b/compat.js
@@ -28,6 +28,26 @@ function convertError(err) {
return err;
}
+function isQueryOptions(value) {
+ return value != null
+ && typeof value === "object"
+ && !Array.isArray(value)
+ && Object.prototype.hasOwnProperty.call(value, "queryTimeout");
+}
+
+function splitBindParameters(bindParameters) {
+ if (bindParameters.length === 0) {
+ return { params: undefined, queryOptions: undefined };
+ }
+ if (bindParameters.length > 1 && isQueryOptions(bindParameters[bindParameters.length - 1])) {
+ return {
+ params: bindParameters.length === 2 ? bindParameters[0] : bindParameters.slice(0, -1),
+ queryOptions: bindParameters[bindParameters.length - 1],
+ };
+ }
+ return { params: bindParameters.length === 1 ? bindParameters[0] : bindParameters, queryOptions: undefined };
+}
+
/**
* Database represents a connection that can prepare and execute SQL statements.
*/
@@ -176,9 +196,9 @@ class Database {
*
* @param {string} sql - The SQL statement string to execute.
*/
- exec(sql) {
+ exec(sql, queryOptions) {
try {
- databaseExecSync(this.db, sql);
+ databaseExecSync(this.db, sql, queryOptions);
} catch (err) {
throw convertError(err);
}
@@ -263,7 +283,8 @@ class Statement {
*/
run(...bindParameters) {
try {
- return statementRunSync(this.stmt, ...bindParameters);
+ const { params, queryOptions } = splitBindParameters(bindParameters);
+ return statementRunSync(this.stmt, params, queryOptions);
} catch (err) {
throw convertError(err);
}
@@ -276,7 +297,8 @@ class Statement {
*/
get(...bindParameters) {
try {
- return statementGetSync(this.stmt, ...bindParameters);
+ const { params, queryOptions } = splitBindParameters(bindParameters);
+ return statementGetSync(this.stmt, params, queryOptions);
} catch (err) {
throw convertError(err);
}
@@ -289,7 +311,8 @@ class Statement {
*/
iterate(...bindParameters) {
try {
- const it = statementIterateSync(this.stmt, ...bindParameters);
+ const { params, queryOptions } = splitBindParameters(bindParameters);
+ const it = statementIterateSync(this.stmt, params, queryOptions);
return {
next: () => iteratorNextSync(it),
[Symbol.iterator]() {
diff --git a/docs/api.md b/docs/api.md
index 33b4a41..46116cb 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -22,6 +22,7 @@ You can use the `options` parameter to specify various options. Options supporte
- `syncPeriod`: synchronize the database periodically every `syncPeriod` seconds.
- `authToken`: authentication token for the provider URL (optional).
- `timeout`: number of milliseconds to wait on locked database before returning `SQLITE_BUSY` error
+- `defaultQueryTimeout`: default maximum number of milliseconds a query is allowed to run before being interrupted with `SQLITE_INTERRUPT` error
The function returns a `Database` object.
@@ -97,13 +98,14 @@ const stmt = db.prepare("SELECT * FROM users");
Loads a SQLite3 extension
-### exec(sql) ⇒ this
+### exec(sql[, queryOptions]) ⇒ this
Executes a SQL statement.
| Param | Type | Description |
| ------ | ------------------- | ------------------------------------ |
| sql | string | The SQL statement string to execute. |
+| queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). |
### interrupt() ⇒ this
@@ -119,39 +121,43 @@ Closes the database connection.
## Methods
-### run([...bindParameters]) ⇒ object
+### run([...bindParameters][, queryOptions]) ⇒ object
Executes the SQL statement and returns an info object.
| Param | Type | Description |
| -------------- | ----------------------------- | ------------------------------------------------ |
| bindParameters | array of objects | The bind parameters for executing the statement. |
+| queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). |
The returned info object contains two properties: `changes` that describes the number of modified rows and `info.lastInsertRowid` that represents the `rowid` of the last inserted row.
-### get([...bindParameters]) ⇒ row
+### get([...bindParameters][, queryOptions]) ⇒ row
Executes the SQL statement and returns the first row.
| Param | Type | Description |
| -------------- | ----------------------------- | ------------------------------------------------ |
| bindParameters | array of objects | The bind parameters for executing the statement. |
+| queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). |
-### all([...bindParameters]) ⇒ array of rows
+### all([...bindParameters][, queryOptions]) ⇒ array of rows
Executes the SQL statement and returns an array of the resulting rows.
| Param | Type | Description |
| -------------- | ----------------------------- | ------------------------------------------------ |
| bindParameters | array of objects | The bind parameters for executing the statement. |
+| queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). |
-### iterate([...bindParameters]) ⇒ iterator
+### iterate([...bindParameters][, queryOptions]) ⇒ iterator
Executes the SQL statement and returns an iterator to the resulting rows.
| Param | Type | Description |
| -------------- | ----------------------------- | ------------------------------------------------ |
| bindParameters | array of objects | The bind parameters for executing the statement. |
+| queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). |
### pluck([toggleState]) ⇒ this
diff --git a/index.d.ts b/index.d.ts
index 15b33b4..367e13e 100644
--- a/index.d.ts
+++ b/index.d.ts
@@ -13,6 +13,11 @@ export interface Options {
encryptionCipher?: string
encryptionKey?: string
remoteEncryptionKey?: string
+ defaultQueryTimeout?: number
+}
+/** Per-query execution options. */
+export interface QueryOptions {
+ queryTimeout?: number
}
export declare function connect(path: string, opts?: Options | undefined | null): Promise
/** Result of a database sync operation. */
@@ -27,12 +32,12 @@ export declare function databasePrepareSync(db: Database, sql: string): Statemen
/** Syncs the database in blocking mode. */
export declare function databaseSyncSync(db: Database): SyncResult
/** Executes SQL in blocking mode. */
-export declare function databaseExecSync(db: Database, sql: string): void
+export declare function databaseExecSync(db: Database, sql: string, queryOptions?: QueryOptions | undefined | null): void
/** Gets first row from statement in blocking mode. */
-export declare function statementGetSync(stmt: Statement, params?: unknown | undefined | null): unknown
+export declare function statementGetSync(stmt: Statement, params?: unknown | undefined | null, queryOptions?: QueryOptions | undefined | null): unknown
/** Runs a statement in blocking mode. */
-export declare function statementRunSync(stmt: Statement, params?: unknown | undefined | null): RunResult
-export declare function statementIterateSync(stmt: Statement, params?: unknown | undefined | null): RowsIterator
+export declare function statementRunSync(stmt: Statement, params?: unknown | undefined | null, queryOptions?: QueryOptions | undefined | null): RunResult
+export declare function statementIterateSync(stmt: Statement, params?: unknown | undefined | null, queryOptions?: QueryOptions | undefined | null): RowsIterator
/** SQLite `run()` result object */
export interface RunResult {
changes: number
@@ -116,7 +121,7 @@ export declare class Database {
* * `env` - The environment.
* * `sql` - The SQL statement to execute.
*/
- exec(sql: string): Promise
+ exec(sql: string, queryOptions?: QueryOptions | undefined | null): Promise
/**
* Syncs the database.
*
@@ -153,7 +158,7 @@ export declare class Statement {
*
* * `params` - The parameters to bind to the statement.
*/
- run(params?: unknown | undefined | null): RunResult
+ run(params?: unknown | undefined | null, queryOptions?: QueryOptions | undefined | null): object
/**
* Executes a SQL statement and returns the first row.
*
@@ -162,7 +167,7 @@ export declare class Statement {
* * `env` - The environment.
* * `params` - The parameters to bind to the statement.
*/
- get(params?: unknown | undefined | null): object
+ get(params?: unknown | undefined | null, queryOptions?: QueryOptions | undefined | null): object
/**
* Create an iterator over the rows of a statement.
*
@@ -171,7 +176,7 @@ export declare class Statement {
* * `env` - The environment.
* * `params` - The parameters to bind to the statement.
*/
- iterate(params?: unknown | undefined | null): object
+ iterate(params?: unknown | undefined | null, queryOptions?: QueryOptions | undefined | null): object
raw(raw?: boolean | undefined | null): this
pluck(pluck?: boolean | undefined | null): this
timing(timing?: boolean | undefined | null): this
diff --git a/integration-tests/tests/async.test.js b/integration-tests/tests/async.test.js
index 096fde8..cf5a205 100644
--- a/integration-tests/tests/async.test.js
+++ b/integration-tests/tests/async.test.js
@@ -398,6 +398,56 @@ test.serial("Timeout option", async (t) => {
fs.unlinkSync(path);
});
+test.serial("Query timeout option interrupts long-running query", async (t) => {
+ const queryTimeout = 100;
+ const [db, errorType] = await connect(":memory:", { defaultQueryTimeout: queryTimeout });
+ const stmt = await db.prepare(
+ "WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;"
+ );
+
+ await t.throwsAsync(async () => {
+ await stmt.all();
+ }, {
+ instanceOf: errorType,
+ message: "interrupted",
+ code: "SQLITE_INTERRUPT",
+ });
+
+ db.close();
+});
+
+test.serial("Query timeout option allows short-running query", async (t) => {
+ const [db] = await connect(":memory:", { defaultQueryTimeout: 100 });
+ const stmt = await db.prepare("SELECT 1 AS value");
+ t.deepEqual(await stmt.get(), { value: 1 });
+ db.close();
+});
+
+test.serial("Per-query timeout option interrupts long-running Statement.all()", async (t) => {
+ const [db, errorType] = await connect(":memory:");
+ const stmt = await db.prepare(
+ "WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;"
+ );
+
+ await t.throwsAsync(async () => {
+ await stmt.all(undefined, { queryTimeout: 100 });
+ }, {
+ instanceOf: errorType,
+ message: "interrupted",
+ code: "SQLITE_INTERRUPT",
+ });
+
+ db.close();
+});
+
+test.serial("Per-query timeout option is accepted by Database.exec()", async (t) => {
+ const [db] = await connect(":memory:");
+ await db.exec("SELECT 1", { queryTimeout: 100 });
+ t.pass();
+
+ db.close();
+});
+
test.serial("Concurrent writes over same connection", async (t) => {
const db = t.context.db;
await db.exec(`
diff --git a/integration-tests/tests/sync.test.js b/integration-tests/tests/sync.test.js
index 5208e8d..8bbf18e 100644
--- a/integration-tests/tests/sync.test.js
+++ b/integration-tests/tests/sync.test.js
@@ -457,6 +457,75 @@ test.serial("Timeout option", async (t) => {
fs.unlinkSync(path);
});
+test.serial("Query timeout option interrupts long-running query", async (t) => {
+ if (t.context.provider === "sqlite") {
+ t.assert(true);
+ return;
+ }
+
+ const [db, errorType] = await connect(":memory:", { defaultQueryTimeout: 100 });
+ const stmt = db.prepare(
+ "WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;"
+ );
+
+ t.throws(() => {
+ stmt.all();
+ }, {
+ instanceOf: errorType,
+ message: "interrupted",
+ code: "SQLITE_INTERRUPT",
+ });
+
+ db.close();
+});
+
+test.serial("Query timeout option allows short-running query", async (t) => {
+ if (t.context.provider === "sqlite") {
+ t.assert(true);
+ return;
+ }
+
+ const [db] = await connect(":memory:", { defaultQueryTimeout: 100 });
+ const stmt = db.prepare("SELECT 1 AS value");
+ t.deepEqual(stmt.get(), { value: 1 });
+ db.close();
+});
+
+test.serial("Per-query timeout option interrupts long-running Statement.all()", async (t) => {
+ if (t.context.provider === "sqlite") {
+ t.assert(true);
+ return;
+ }
+
+ const [db, errorType] = await connect(":memory:");
+ const stmt = db.prepare(
+ "WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;"
+ );
+
+ t.throws(() => {
+ stmt.all(undefined, { queryTimeout: 100 });
+ }, {
+ instanceOf: errorType,
+ message: "interrupted",
+ code: "SQLITE_INTERRUPT",
+ });
+
+ db.close();
+});
+
+test.serial("Per-query timeout option is accepted by Database.exec()", async (t) => {
+ if (t.context.provider === "sqlite") {
+ t.assert(true);
+ return;
+ }
+
+ const [db] = await connect(":memory:");
+ db.exec("SELECT 1", { queryTimeout: 100 });
+ t.pass();
+
+ db.close();
+});
+
test.serial("Statement.reader [SELECT is true]", async (t) => {
const db = t.context.db;
const stmt = db.prepare("SELECT * FROM users WHERE id = ?");
diff --git a/promise.js b/promise.js
index 8a79863..f364013 100644
--- a/promise.js
+++ b/promise.js
@@ -41,6 +41,26 @@ function convertError(err) {
return err;
}
+function isQueryOptions(value) {
+ return value != null
+ && typeof value === "object"
+ && !Array.isArray(value)
+ && Object.prototype.hasOwnProperty.call(value, "queryTimeout");
+}
+
+function splitBindParameters(bindParameters) {
+ if (bindParameters.length === 0) {
+ return { params: undefined, queryOptions: undefined };
+ }
+ if (bindParameters.length > 1 && isQueryOptions(bindParameters[bindParameters.length - 1])) {
+ return {
+ params: bindParameters.length === 2 ? bindParameters[0] : bindParameters.slice(0, -1),
+ queryOptions: bindParameters[bindParameters.length - 1],
+ };
+ }
+ return { params: bindParameters.length === 1 ? bindParameters[0] : bindParameters, queryOptions: undefined };
+}
+
/**
* Creates a new database connection.
*
@@ -217,9 +237,9 @@ class Database {
*
* @param {string} sql - The SQL statement string to execute.
*/
- async exec(sql) {
+ async exec(sql, queryOptions) {
try {
- await this.db.exec(sql);
+ await this.db.exec(sql, queryOptions);
} catch (err) {
throw convertError(err);
}
@@ -308,7 +328,8 @@ class Statement {
*/
async run(...bindParameters) {
try {
- return await this.stmt.run(...bindParameters);
+ const { params, queryOptions } = splitBindParameters(bindParameters);
+ return await this.stmt.run(params, queryOptions);
} catch (err) {
throw convertError(err);
}
@@ -321,7 +342,8 @@ class Statement {
*/
async get(...bindParameters) {
try {
- return await this.stmt.get(...bindParameters);
+ const { params, queryOptions } = splitBindParameters(bindParameters);
+ return await this.stmt.get(params, queryOptions);
} catch (err) {
throw convertError(err);
}
@@ -334,7 +356,8 @@ class Statement {
*/
async iterate(...bindParameters) {
try {
- const it = await this.stmt.iterate(...bindParameters);
+ const { params, queryOptions } = splitBindParameters(bindParameters);
+ const it = await this.stmt.iterate(params, queryOptions);
return {
next() {
return it.next();
diff --git a/src/lib.rs b/src/lib.rs
index 0dbdf9d..61e1581 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -21,6 +21,7 @@
#![allow(deprecated)]
mod auth;
+mod query_timeout;
use napi::{
bindgen_prelude::{Array, FromNapiValue, ToNapiValue},
@@ -28,6 +29,7 @@ use napi::{
};
use napi_derive::napi;
use once_cell::sync::OnceCell;
+use query_timeout::{QueryTimeoutManager, TimeoutGuard};
use std::{
str::FromStr,
sync::{
@@ -200,6 +202,15 @@ pub struct Options {
pub encryptionKey: Option,
// Encryption key for remote encryption at rest.
pub remoteEncryptionKey: Option,
+ // Default maximum time in milliseconds that a query is allowed to run.
+ pub defaultQueryTimeout: Option,
+}
+
+/// Per-query execution options.
+#[napi(object)]
+pub struct QueryOptions {
+ // Maximum time in milliseconds that this query is allowed to run.
+ pub queryTimeout: Option,
}
/// Access mode.
@@ -224,10 +235,15 @@ pub struct Database {
default_safe_integers: AtomicBool,
// Whether to use memory-only mode.
memory: bool,
+ // Maximum time in milliseconds that a query is allowed to run.
+ query_timeout: Option,
+ // Shared timeout manager for efficient query timeout handling.
+ timeout_manager: Arc,
}
impl Drop for Database {
fn drop(&mut self) {
+ self.timeout_manager.shutdown();
self.conn = None;
self.db = None;
}
@@ -321,11 +337,18 @@ pub async fn connect(path: String, opts: Option) -> Result {
conn.busy_timeout(Duration::from_millis(timeout as u64))
.map_err(Error::from)?
}
+ let query_timeout = opts
+ .as_ref()
+ .and_then(|o| o.defaultQueryTimeout)
+ .and_then(query_timeout_duration);
+ let timeout_manager = Arc::new(QueryTimeoutManager::new());
Ok(Database {
db: Some(db),
conn: Some(Arc::new(conn)),
default_safe_integers,
memory,
+ query_timeout,
+ timeout_manager,
})
}
@@ -388,7 +411,13 @@ impl Database {
pluck: false.into(),
timing: false.into(),
};
- Ok(Statement::new(conn, stmt, mode))
+ Ok(Statement::new(
+ conn,
+ stmt,
+ mode,
+ self.query_timeout,
+ self.timeout_manager.clone(),
+ ))
}
/// Sets the authorizer for the database.
@@ -509,7 +538,7 @@ impl Database {
/// * `env` - The environment.
/// * `sql` - The SQL statement to execute.
#[napi]
- pub async fn exec(&self, sql: String) -> Result<()> {
+ pub async fn exec(&self, sql: String, query_options: Option) -> Result<()> {
let conn = match &self.conn {
Some(conn) => conn.clone(),
None => {
@@ -520,6 +549,11 @@ impl Database {
));
}
};
+ let query_timeout = match query_options.and_then(|o| o.queryTimeout) {
+ Some(timeout_ms) => query_timeout_duration(timeout_ms),
+ None => self.query_timeout,
+ };
+ let _guard = query_timeout.map(|t| self.timeout_manager.register(&conn, t));
conn.execute_batch(&sql).await.map_err(Error::from)?;
Ok(())
}
@@ -566,6 +600,7 @@ impl Database {
/// Closes the database connection.
#[napi]
pub fn close(&mut self) -> Result<()> {
+ self.timeout_manager.shutdown();
self.conn = None;
self.db = None;
Ok(())
@@ -609,9 +644,13 @@ pub fn database_sync_sync(db: &Database) -> Result {
/// Executes SQL in blocking mode.
#[napi]
-pub fn database_exec_sync(db: &Database, sql: String) -> Result<()> {
+pub fn database_exec_sync(
+ db: &Database,
+ sql: String,
+ query_options: Option,
+) -> Result<()> {
let rt = runtime()?;
- rt.block_on(async move { db.exec(sql).await })
+ rt.block_on(async move { db.exec(sql, query_options).await })
}
fn is_remote_path(path: &str) -> bool {
@@ -625,6 +664,14 @@ fn throw_database_closed_error(env: &Env) -> napi::Error {
err
}
+fn query_timeout_duration(timeout_ms: f64) -> Option {
+ if timeout_ms.is_finite() && timeout_ms > 0.0 {
+ Some(Duration::from_millis(timeout_ms as u64))
+ } else {
+ None
+ }
+}
+
/// SQLite statement object.
#[napi]
pub struct Statement {
@@ -636,6 +683,10 @@ pub struct Statement {
column_names: Vec,
// The access mode.
mode: AccessMode,
+ // Maximum time in milliseconds that a query is allowed to run.
+ query_timeout: Option,
+ // Shared timeout manager.
+ timeout_manager: Arc,
}
#[napi]
@@ -651,6 +702,8 @@ impl Statement {
conn: Arc,
stmt: libsql::Statement,
mode: AccessMode,
+ query_timeout: Option,
+ timeout_manager: Arc,
) -> Self {
let column_names: Vec = stmt
.columns()
@@ -663,6 +716,8 @@ impl Statement {
stmt,
column_names,
mode,
+ query_timeout,
+ timeout_manager,
}
}
@@ -672,15 +727,22 @@ impl Statement {
///
/// * `params` - The parameters to bind to the statement.
#[napi]
- pub fn run(&self, env: Env, params: Option) -> Result {
+ pub fn run(
+ &self,
+ env: Env,
+ params: Option,
+ query_options: Option,
+ ) -> Result {
self.stmt.reset();
let params = map_params(&self.stmt, params)?;
let total_changes_before = self.conn.total_changes();
let start = std::time::Instant::now();
let stmt = self.stmt.clone();
let conn = self.conn.clone();
+ let guard = self.start_timeout_guard(query_options);
let future = async move {
+ let _guard = guard;
stmt.run(params).await.map_err(Error::from)?;
let changes = if conn.total_changes() == total_changes_before {
0
@@ -706,7 +768,12 @@ impl Statement {
/// * `env` - The environment.
/// * `params` - The parameters to bind to the statement.
#[napi]
- pub fn get(&self, env: Env, params: Option) -> Result {
+ pub fn get(
+ &self,
+ env: Env,
+ params: Option,
+ query_options: Option,
+ ) -> Result {
let safe_ints = self.mode.safe_ints.load(Ordering::SeqCst);
let raw = self.mode.raw.load(Ordering::SeqCst);
let pluck = self.mode.pluck.load(Ordering::SeqCst);
@@ -723,7 +790,9 @@ impl Statement {
};
let stmt_fut = stmt.clone();
+ let guard = self.start_timeout_guard(query_options);
let future = async move {
+ let _guard = guard;
let mut rows = stmt_fut.query(params).await.map_err(Error::from)?;
let row = rows.next().await.map_err(Error::from)?;
let duration: Option = start.map(|start| start.elapsed().as_secs_f64());
@@ -779,7 +848,12 @@ impl Statement {
/// * `env` - The environment.
/// * `params` - The parameters to bind to the statement.
#[napi]
- pub fn iterate(&self, env: Env, params: Option) -> Result {
+ pub fn iterate(
+ &self,
+ env: Env,
+ params: Option,
+ query_options: Option,
+ ) -> Result {
let safe_ints = self.mode.safe_ints.load(Ordering::SeqCst);
let raw = self.mode.raw.load(Ordering::SeqCst);
let pluck = self.mode.pluck.load(Ordering::SeqCst);
@@ -787,6 +861,7 @@ impl Statement {
stmt.reset();
let params = map_params(&stmt, params).unwrap();
let stmt = self.stmt.clone();
+ let guard = self.start_timeout_guard(query_options);
let future = async move {
let rows = stmt.query(params).await.map_err(Error::from)?;
Ok::<_, napi::Error>(rows)
@@ -799,6 +874,7 @@ impl Statement {
safe_ints,
raw,
pluck,
+ guard,
))
})
}
@@ -882,12 +958,27 @@ impl Statement {
}
}
+impl Statement {
+ fn resolve_query_timeout(&self, query_options: Option) -> Option {
+ match query_options.and_then(|o| o.queryTimeout) {
+ Some(timeout_ms) => query_timeout_duration(timeout_ms),
+ None => self.query_timeout,
+ }
+ }
+
+ fn start_timeout_guard(&self, query_options: Option) -> Option {
+ self.resolve_query_timeout(query_options)
+ .map(|t| self.timeout_manager.register(&self.conn, t))
+ }
+}
+
/// Gets first row from statement in blocking mode.
#[napi]
pub fn statement_get_sync(
stmt: &Statement,
env: Env,
params: Option,
+ query_options: Option,
) -> Result {
let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst);
let raw = stmt.mode.raw.load(Ordering::SeqCst);
@@ -901,6 +992,7 @@ pub fn statement_get_sync(
};
let rt = runtime()?;
+ let _guard = stmt.start_timeout_guard(query_options);
rt.block_on(async move {
let params = map_params(&stmt.stmt, params)?;
let mut rows = stmt.stmt.query(params).await.map_err(Error::from)?;
@@ -922,9 +1014,14 @@ pub fn statement_get_sync(
/// Runs a statement in blocking mode.
#[napi]
-pub fn statement_run_sync(stmt: &Statement, params: Option) -> Result {
+pub fn statement_run_sync(
+ stmt: &Statement,
+ params: Option,
+ query_options: Option,
+) -> Result {
stmt.stmt.reset();
let rt = runtime()?;
+ let _guard = stmt.start_timeout_guard(query_options);
rt.block_on(async move {
let params = map_params(&stmt.stmt, params)?;
let total_changes_before = stmt.conn.total_changes();
@@ -951,16 +1048,18 @@ pub fn statement_iterate_sync(
stmt: &Statement,
_env: Env,
params: Option,
+ query_options: Option,
) -> Result {
let rt = runtime()?;
let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst);
let raw = stmt.mode.raw.load(Ordering::SeqCst);
let pluck = stmt.mode.pluck.load(Ordering::SeqCst);
- let stmt = stmt.stmt.clone();
+ let guard = stmt.start_timeout_guard(query_options);
+ let inner_stmt = stmt.stmt.clone();
let (rows, column_names) = rt.block_on(async move {
- stmt.reset();
- let params = map_params(&stmt, params)?;
- let rows = stmt.query(params).await.map_err(Error::from)?;
+ inner_stmt.reset();
+ let params = map_params(&inner_stmt, params)?;
+ let rows = inner_stmt.query(params).await.map_err(Error::from)?;
let mut column_names = Vec::new();
for i in 0..rows.column_count() {
column_names
@@ -974,6 +1073,7 @@ pub fn statement_iterate_sync(
safe_ints,
raw,
pluck,
+ guard,
))
}
@@ -1120,6 +1220,7 @@ pub struct RowsIterator {
safe_ints: bool,
raw: bool,
pluck: bool,
+ _timeout_guard: Option,
}
#[napi]
@@ -1130,6 +1231,7 @@ impl RowsIterator {
safe_ints: bool,
raw: bool,
pluck: bool,
+ timeout_guard: Option,
) -> Self {
Self {
rows,
@@ -1137,6 +1239,7 @@ impl RowsIterator {
safe_ints,
raw,
pluck,
+ _timeout_guard: timeout_guard,
}
}
diff --git a/src/query_timeout.rs b/src/query_timeout.rs
new file mode 100644
index 0000000..7ed504a
--- /dev/null
+++ b/src/query_timeout.rs
@@ -0,0 +1,347 @@
+use std::{
+ cmp::Reverse,
+ collections::BinaryHeap,
+ sync::{
+ atomic::{AtomicBool, Ordering},
+ Arc, Mutex, Weak,
+ },
+ time::Duration,
+};
+use tokio::{sync::Notify, time::Instant};
+
+/// A single-background-task timer wheel that interrupts connections when their
+/// query deadline expires. Registering a query returns a [`TimeoutGuard`] —
+/// dropping the guard cancels the timeout.
+pub struct QueryTimeoutManager {
+ inner: Arc,
+}
+
+struct Inner {
+ entries: Mutex,
+ /// Wakes the background task when the earliest deadline changes.
+ notify: Arc,
+ /// Set to `true` to make the background task exit.
+ shutdown: AtomicBool,
+}
+
+struct Entries {
+ heap: BinaryHeap>,
+ next_id: u64,
+}
+
+#[derive(Clone)]
+struct Entry {
+ id: u64,
+ deadline: Instant,
+ conn: Weak,
+ /// Cleared when the guard is dropped (query finished in time).
+ active: Arc,
+}
+
+impl PartialEq for Entry {
+ fn eq(&self, other: &Self) -> bool {
+ self.deadline == other.deadline && self.id == other.id
+ }
+}
+impl Eq for Entry {}
+impl PartialOrd for Entry {
+ fn partial_cmp(&self, other: &Self) -> Option {
+ Some(self.cmp(other))
+ }
+}
+impl Ord for Entry {
+ fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+ self.deadline
+ .cmp(&other.deadline)
+ .then(self.id.cmp(&other.id))
+ }
+}
+
+impl QueryTimeoutManager {
+ pub fn new() -> Self {
+ let inner = Arc::new(Inner {
+ entries: Mutex::new(Entries {
+ heap: BinaryHeap::new(),
+ next_id: 0,
+ }),
+ notify: Arc::new(Notify::new()),
+ shutdown: AtomicBool::new(false),
+ });
+ let bg = Arc::downgrade(&inner);
+ tokio::spawn(async move {
+ Self::background_task(bg).await;
+ });
+ Self { inner }
+ }
+
+ /// Synchronously remove all entries from the heap, releasing any
+ /// connection references the background task is holding.
+ pub fn clear(&self) {
+ let mut entries = self.inner.entries.lock().unwrap();
+ entries.heap.clear();
+ }
+
+ /// Signal the background task to exit and clear all entries.
+ pub fn shutdown(&self) {
+ self.inner.shutdown.store(true, Ordering::Relaxed);
+ self.clear();
+ self.inner.notify.notify_one();
+ }
+
+ /// Register a query. The returned guard must be held for the duration of
+ /// the query — dropping it cancels the timeout.
+ pub fn register(&self, conn: &Arc, timeout: Duration) -> TimeoutGuard {
+ let active = Arc::new(AtomicBool::new(true));
+ let mut entries = self.inner.entries.lock().unwrap();
+ let id = entries.next_id;
+ entries.next_id += 1;
+ let deadline = Instant::now()
+ .checked_add(timeout)
+ .unwrap_or_else(|| Instant::now() + Duration::from_secs(86400));
+ let entry = Entry {
+ id,
+ deadline,
+ conn: Arc::downgrade(conn),
+ active: active.clone(),
+ };
+ let is_new_earliest = entries
+ .heap
+ .peek()
+ .map_or(true, |Reverse(e)| entry.deadline < e.deadline);
+ entries.heap.push(Reverse(entry));
+ drop(entries);
+ if is_new_earliest {
+ self.inner.notify.notify_one();
+ }
+ TimeoutGuard {
+ active,
+ notify: self.inner.notify.clone(),
+ }
+ }
+
+ async fn background_task(weak: Weak) {
+ loop {
+ let inner = match weak.upgrade() {
+ Some(inner) => inner,
+ None => return, // Manager dropped — exit.
+ };
+
+ if inner.shutdown.load(Ordering::Relaxed) {
+ return;
+ }
+
+ // Find the next deadline, skipping cancelled entries.
+ let next = {
+ let mut entries = inner.entries.lock().unwrap();
+ loop {
+ match entries.heap.peek() {
+ Some(Reverse(e)) if !e.active.load(Ordering::Relaxed) => {
+ entries.heap.pop();
+ }
+ Some(Reverse(e)) => break Some(e.clone()),
+ None => break None,
+ }
+ }
+ };
+
+ match next {
+ Some(entry) => {
+ tokio::select! {
+ _ = tokio::time::sleep_until(entry.deadline) => {
+ // Deadline reached — interrupt if still active.
+ if entry.active.load(Ordering::Relaxed) {
+ if let Some(conn) = entry.conn.upgrade() {
+ let _ = conn.interrupt();
+ }
+ }
+ // Remove this entry.
+ let mut entries = inner.entries.lock().unwrap();
+ // Pop entries that are done (expired or cancelled).
+ while let Some(Reverse(e)) = entries.heap.peek() {
+ if !e.active.load(Ordering::Relaxed) || e.id == entry.id {
+ entries.heap.pop();
+ } else {
+ break;
+ }
+ }
+ }
+ _ = inner.notify.notified() => {
+ // A new earlier deadline was added; re-check.
+ }
+ }
+ }
+ None => {
+ // Nothing to do — wait until a new entry is registered.
+ // Must hold the Arc while waiting so we can detect drop next iteration.
+ inner.notify.notified().await;
+ }
+ }
+ }
+ }
+}
+
+impl Drop for QueryTimeoutManager {
+ fn drop(&mut self) {
+ // Signal the background task to exit.
+ self.inner.shutdown.store(true, Ordering::Relaxed);
+ self.inner.notify.notify_one();
+ }
+}
+
+/// Dropping this guard cancels the associated query timeout.
+pub struct TimeoutGuard {
+ active: Arc,
+ notify: Arc,
+}
+
+impl Drop for TimeoutGuard {
+ fn drop(&mut self) {
+ self.active.store(false, Ordering::Relaxed);
+ // Wake the background task so it can clean up the cancelled entry.
+ self.notify.notify_one();
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::sync::Arc;
+
+ async fn test_conn() -> Arc {
+ let db = libsql::Builder::new_local(":memory:")
+ .build()
+ .await
+ .unwrap();
+ Arc::new(db.connect().unwrap())
+ }
+
+ #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+ #[ntest::timeout(10000)]
+ async fn deadline_expires_interrupts_connection() {
+ let conn = test_conn().await;
+ let mgr = QueryTimeoutManager::new();
+
+ // Register a 200ms timeout, then start an infinite query.
+ let _guard = mgr.register(&conn, Duration::from_millis(200));
+ let fut = {
+ let conn = conn.clone();
+ tokio::spawn(async move {
+ conn.execute_batch(
+ "WITH RECURSIVE r(n) AS (SELECT 1 UNION ALL SELECT n+1 FROM r) SELECT * FROM r",
+ )
+ .await
+ })
+ };
+
+ // The query should have been interrupted by the timeout.
+ let result = fut.await.unwrap();
+ assert!(result.is_err(), "query should have been interrupted");
+ }
+
+ #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+ #[ntest::timeout(10000)]
+ async fn guard_dropped_before_deadline_cancels_timeout() {
+ let conn = test_conn().await;
+ let mgr = QueryTimeoutManager::new();
+
+ let guard = mgr.register(&conn, Duration::from_millis(200));
+
+ // Query "finishes" before the deadline.
+ drop(guard);
+
+ // Wait past where the deadline would have been.
+ tokio::time::sleep(Duration::from_millis(300)).await;
+
+ // Connection should still work — no spurious interrupt.
+ let result = conn.execute_batch("SELECT 1").await;
+ assert!(
+ result.is_ok(),
+ "connection should not have been interrupted"
+ );
+ }
+
+ #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+ #[ntest::timeout(10000)]
+ async fn guard_drop_cleans_entry_from_heap() {
+ let conn = test_conn().await;
+ let mgr = QueryTimeoutManager::new();
+
+ let guard = mgr.register(&conn, Duration::from_millis(500));
+
+ drop(guard);
+ // Let the background task wake up and clean the cancelled entry.
+ tokio::time::sleep(Duration::from_millis(50)).await;
+
+ let entries = mgr.inner.entries.lock().unwrap();
+ assert_eq!(
+ entries.heap.len(),
+ 0,
+ "dropping guard should clean up the entry from the heap"
+ );
+ }
+
+ #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+ #[ntest::timeout(10000)]
+ async fn multiple_deadlines_fire_in_order() {
+ let conn = test_conn().await;
+ let mgr = QueryTimeoutManager::new();
+
+ // Register three timeouts at 100ms, 200ms, 300ms.
+ let _g1 = mgr.register(&conn, Duration::from_millis(100));
+ let _g2 = mgr.register(&conn, Duration::from_millis(200));
+ let _g3 = mgr.register(&conn, Duration::from_millis(300));
+
+ // After 150ms, only the first should have fired.
+ tokio::time::sleep(Duration::from_millis(150)).await;
+ {
+ let entries = mgr.inner.entries.lock().unwrap();
+ assert_eq!(entries.heap.len(), 2, "only first entry should have fired");
+ }
+
+ // After 350ms total, all three should have fired.
+ tokio::time::sleep(Duration::from_millis(200)).await;
+ {
+ let entries = mgr.inner.entries.lock().unwrap();
+ assert_eq!(entries.heap.len(), 0, "all entries should be cleaned up");
+ }
+ }
+
+ #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+ #[ntest::timeout(10000)]
+ async fn new_earlier_deadline_preempts_existing() {
+ let conn = test_conn().await;
+ let mgr = QueryTimeoutManager::new();
+
+ // Register a long timeout first, then a shorter one that should preempt.
+ let _g1 = mgr.register(&conn, Duration::from_millis(5000));
+ let _g2 = mgr.register(&conn, Duration::from_millis(200));
+
+ let fut = {
+ let conn = conn.clone();
+ tokio::spawn(async move {
+ conn.execute_batch(
+ "WITH RECURSIVE r(n) AS (SELECT 1 UNION ALL SELECT n+1 FROM r) SELECT * FROM r",
+ )
+ .await
+ })
+ };
+
+ let result = fut.await.unwrap();
+ assert!(
+ result.is_err(),
+ "shorter deadline should have interrupted the query"
+ );
+ }
+
+ #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+ #[ntest::timeout(10000)]
+ async fn background_task_exits_when_manager_dropped() {
+ let conn = test_conn().await;
+ let mgr = QueryTimeoutManager::new();
+ let guard = mgr.register(&conn, Duration::from_millis(5000));
+ drop(guard);
+ drop(mgr);
+ // If the background task didn't exit, it would leak — verify no panic.
+ tokio::time::sleep(Duration::from_millis(50)).await;
+ }
+}