Skip to content

Commit 432e757

Browse files
committed
Add query timeout option to interrupt long-running queries
A single background tokio task with a min-heap manages all query deadlines efficiently. When a query starts, a TimeoutGuard is acquired; if the deadline expires before the guard is dropped, the connection is interrupted via sqlite3_interrupt().
1 parent fe57952 commit 432e757

3 files changed

Lines changed: 204 additions & 5 deletions

File tree

docs/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ You can use the `options` parameter to specify various options. Options supporte
2222
- `syncPeriod`: synchronize the database periodically every `syncPeriod` seconds.
2323
- `authToken`: authentication token for the provider URL (optional).
2424
- `timeout`: number of milliseconds to wait on locked database before returning `SQLITE_BUSY` error
25+
- `queryTimeout`: maximum number of milliseconds a query is allowed to run before being interrupted with `SQLITE_INTERRUPT` error
2526

2627
The function returns a `Database` object.
2728

src/lib.rs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#![allow(deprecated)]
2222

2323
mod auth;
24+
mod query_timeout;
2425

2526
use napi::{
2627
bindgen_prelude::{Array, FromNapiValue, ToNapiValue},
@@ -200,6 +201,8 @@ pub struct Options {
200201
pub encryptionKey: Option<String>,
201202
// Encryption key for remote encryption at rest.
202203
pub remoteEncryptionKey: Option<String>,
204+
// Maximum time in milliseconds that a query is allowed to run.
205+
pub queryTimeout: Option<f64>,
203206
}
204207

205208
/// Access mode.
@@ -224,6 +227,10 @@ pub struct Database {
224227
default_safe_integers: AtomicBool,
225228
// Whether to use memory-only mode.
226229
memory: bool,
230+
// Maximum time in milliseconds that a query is allowed to run.
231+
query_timeout: Option<Duration>,
232+
// Shared timeout manager for efficient query timeout handling.
233+
timeout_manager: Arc<query_timeout::QueryTimeoutManager>,
227234
}
228235

229236
impl Drop for Database {
@@ -320,11 +327,19 @@ pub async fn connect(path: String, opts: Option<Options>) -> Result<Database> {
320327
conn.busy_timeout(Duration::from_millis(timeout as u64))
321328
.map_err(Error::from)?
322329
}
330+
let query_timeout = opts
331+
.as_ref()
332+
.and_then(|o| o.queryTimeout)
333+
.filter(|&t| t > 0.0)
334+
.map(|t| Duration::from_millis(t as u64));
335+
let timeout_manager = Arc::new(query_timeout::QueryTimeoutManager::new());
323336
Ok(Database {
324337
db,
325338
conn: Some(Arc::new(conn)),
326339
default_safe_integers,
327340
memory,
341+
query_timeout,
342+
timeout_manager,
328343
})
329344
}
330345

@@ -387,7 +402,7 @@ impl Database {
387402
pluck: false.into(),
388403
timing: false.into(),
389404
};
390-
Ok(Statement::new(conn, stmt, mode))
405+
Ok(Statement::new(conn, stmt, mode, self.query_timeout, self.timeout_manager.clone()))
391406
}
392407

393408
/// Sets the authorizer for the database.
@@ -515,6 +530,7 @@ impl Database {
515530
));
516531
}
517532
};
533+
let _guard = self.query_timeout.map(|t| self.timeout_manager.register(&conn, t));
518534
conn.execute_batch(&sql).await.map_err(Error::from)?;
519535
Ok(())
520536
}
@@ -620,6 +636,10 @@ pub struct Statement {
620636
column_names: Vec<std::ffi::CString>,
621637
// The access mode.
622638
mode: AccessMode,
639+
// Maximum time in milliseconds that a query is allowed to run.
640+
query_timeout: Option<Duration>,
641+
// Shared timeout manager.
642+
timeout_manager: Arc<query_timeout::QueryTimeoutManager>,
623643
}
624644

625645
#[napi]
@@ -635,6 +655,8 @@ impl Statement {
635655
conn: Arc<libsql::Connection>,
636656
stmt: libsql::Statement,
637657
mode: AccessMode,
658+
query_timeout: Option<Duration>,
659+
timeout_manager: Arc<query_timeout::QueryTimeoutManager>,
638660
) -> Self {
639661
let column_names: Vec<std::ffi::CString> = stmt
640662
.columns()
@@ -647,6 +669,8 @@ impl Statement {
647669
stmt,
648670
column_names,
649671
mode,
672+
query_timeout,
673+
timeout_manager,
650674
}
651675
}
652676

@@ -663,8 +687,10 @@ impl Statement {
663687
let start = std::time::Instant::now();
664688
let stmt = self.stmt.clone();
665689
let conn = self.conn.clone();
690+
let guard = self.start_timeout_guard();
666691

667692
let future = async move {
693+
let _guard = guard;
668694
stmt.run(params).await.map_err(Error::from)?;
669695
let changes = if conn.total_changes() == total_changes_before {
670696
0
@@ -707,7 +733,9 @@ impl Statement {
707733
};
708734

709735
let stmt_fut = stmt.clone();
736+
let guard = self.start_timeout_guard();
710737
let future = async move {
738+
let _guard = guard;
711739
let mut rows = stmt_fut.query(params).await.map_err(Error::from)?;
712740
let row = rows.next().await.map_err(Error::from)?;
713741
let duration: Option<f64> = start.map(|start| start.elapsed().as_secs_f64());
@@ -771,6 +799,7 @@ impl Statement {
771799
stmt.reset();
772800
let params = map_params(&stmt, params).unwrap();
773801
let stmt = self.stmt.clone();
802+
let guard = self.start_timeout_guard();
774803
let future = async move {
775804
let rows = stmt.query(params).await.map_err(Error::from)?;
776805
Ok::<_, napi::Error>(rows)
@@ -783,6 +812,7 @@ impl Statement {
783812
safe_ints,
784813
raw,
785814
pluck,
815+
guard,
786816
))
787817
})
788818
}
@@ -864,6 +894,13 @@ impl Statement {
864894
self.stmt.interrupt().map_err(Error::from)?;
865895
Ok(())
866896
}
897+
898+
}
899+
900+
impl Statement {
901+
fn start_timeout_guard(&self) -> Option<query_timeout::TimeoutGuard> {
902+
self.query_timeout.map(|t| self.timeout_manager.register(&self.conn, t))
903+
}
867904
}
868905

869906
/// Gets first row from statement in blocking mode.
@@ -885,6 +922,7 @@ pub fn statement_get_sync(
885922
};
886923

887924
let rt = runtime()?;
925+
let _guard = stmt.start_timeout_guard();
888926
rt.block_on(async move {
889927
let params = map_params(&stmt.stmt, params)?;
890928
let mut rows = stmt.stmt.query(params).await.map_err(Error::from)?;
@@ -909,6 +947,7 @@ pub fn statement_get_sync(
909947
pub fn statement_run_sync(stmt: &Statement, params: Option<napi::JsUnknown>) -> Result<RunResult> {
910948
stmt.stmt.reset();
911949
let rt = runtime()?;
950+
let _guard = stmt.start_timeout_guard();
912951
rt.block_on(async move {
913952
let params = map_params(&stmt.stmt, params)?;
914953
let total_changes_before = stmt.conn.total_changes();
@@ -940,11 +979,12 @@ pub fn statement_iterate_sync(
940979
let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst);
941980
let raw = stmt.mode.raw.load(Ordering::SeqCst);
942981
let pluck = stmt.mode.pluck.load(Ordering::SeqCst);
943-
let stmt = stmt.stmt.clone();
982+
let guard = stmt.start_timeout_guard();
983+
let inner_stmt = stmt.stmt.clone();
944984
let (rows, column_names) = rt.block_on(async move {
945-
stmt.reset();
946-
let params = map_params(&stmt, params)?;
947-
let rows = stmt.query(params).await.map_err(Error::from)?;
985+
inner_stmt.reset();
986+
let params = map_params(&inner_stmt, params)?;
987+
let rows = inner_stmt.query(params).await.map_err(Error::from)?;
948988
let mut column_names = Vec::new();
949989
for i in 0..rows.column_count() {
950990
column_names
@@ -958,6 +998,7 @@ pub fn statement_iterate_sync(
958998
safe_ints,
959999
raw,
9601000
pluck,
1001+
guard,
9611002
))
9621003
}
9631004

@@ -1104,6 +1145,7 @@ pub struct RowsIterator {
11041145
safe_ints: bool,
11051146
raw: bool,
11061147
pluck: bool,
1148+
_timeout_guard: Option<query_timeout::TimeoutGuard>,
11071149
}
11081150

11091151
#[napi]
@@ -1114,13 +1156,15 @@ impl RowsIterator {
11141156
safe_ints: bool,
11151157
raw: bool,
11161158
pluck: bool,
1159+
timeout_guard: Option<query_timeout::TimeoutGuard>,
11171160
) -> Self {
11181161
Self {
11191162
rows,
11201163
column_names,
11211164
safe_ints,
11221165
raw,
11231166
pluck,
1167+
_timeout_guard: timeout_guard,
11241168
}
11251169
}
11261170

src/query_timeout.rs

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
use std::collections::BinaryHeap;
2+
use std::cmp::Reverse;
3+
use std::sync::{Arc, Mutex};
4+
use std::time::Duration;
5+
use tokio::sync::Notify;
6+
use tokio::time::Instant;
7+
8+
/// A single-background-task timer wheel that interrupts connections when their
9+
/// query deadline expires. Registering a query returns a [`TimeoutGuard`] —
10+
/// dropping the guard cancels the timeout.
11+
pub struct QueryTimeoutManager {
12+
inner: Arc<Inner>,
13+
}
14+
15+
struct Inner {
16+
entries: Mutex<Entries>,
17+
/// Wakes the background task when the earliest deadline changes.
18+
notify: Notify,
19+
}
20+
21+
struct Entries {
22+
heap: BinaryHeap<Reverse<Entry>>,
23+
next_id: u64,
24+
}
25+
26+
#[derive(Clone)]
27+
struct Entry {
28+
id: u64,
29+
deadline: Instant,
30+
conn: Arc<libsql::Connection>,
31+
/// Cleared when the guard is dropped (query finished in time).
32+
active: Arc<AtomicBool>,
33+
}
34+
35+
use std::sync::atomic::AtomicBool;
36+
use std::sync::atomic::Ordering;
37+
38+
impl PartialEq for Entry {
39+
fn eq(&self, other: &Self) -> bool {
40+
self.deadline == other.deadline && self.id == other.id
41+
}
42+
}
43+
impl Eq for Entry {}
44+
impl PartialOrd for Entry {
45+
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
46+
Some(self.cmp(other))
47+
}
48+
}
49+
impl Ord for Entry {
50+
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
51+
self.deadline.cmp(&other.deadline).then(self.id.cmp(&other.id))
52+
}
53+
}
54+
55+
impl QueryTimeoutManager {
56+
pub fn new() -> Self {
57+
let inner = Arc::new(Inner {
58+
entries: Mutex::new(Entries {
59+
heap: BinaryHeap::new(),
60+
next_id: 0,
61+
}),
62+
notify: Notify::new(),
63+
});
64+
let bg = inner.clone();
65+
tokio::spawn(async move {
66+
Self::background_task(bg).await;
67+
});
68+
Self { inner }
69+
}
70+
71+
/// Register a query. The returned guard must be held for the duration of
72+
/// the query — dropping it cancels the timeout.
73+
pub fn register(&self, conn: &Arc<libsql::Connection>, timeout: Duration) -> TimeoutGuard {
74+
let active = Arc::new(AtomicBool::new(true));
75+
let mut entries = self.inner.entries.lock().unwrap();
76+
let id = entries.next_id;
77+
entries.next_id += 1;
78+
let entry = Entry {
79+
id,
80+
deadline: Instant::now() + timeout,
81+
conn: conn.clone(),
82+
active: active.clone(),
83+
};
84+
let is_new_earliest = entries
85+
.heap
86+
.peek()
87+
.map_or(true, |Reverse(e)| entry.deadline < e.deadline);
88+
entries.heap.push(Reverse(entry));
89+
drop(entries);
90+
if is_new_earliest {
91+
self.inner.notify.notify_one();
92+
}
93+
TimeoutGuard { active }
94+
}
95+
96+
async fn background_task(inner: Arc<Inner>) {
97+
loop {
98+
// Find the next deadline, skipping cancelled entries.
99+
let next = {
100+
let mut entries = inner.entries.lock().unwrap();
101+
loop {
102+
match entries.heap.peek() {
103+
Some(Reverse(e)) if !e.active.load(Ordering::Relaxed) => {
104+
entries.heap.pop();
105+
}
106+
Some(Reverse(e)) => break Some(e.clone()),
107+
None => break None,
108+
}
109+
}
110+
};
111+
112+
match next {
113+
Some(entry) => {
114+
tokio::select! {
115+
_ = tokio::time::sleep_until(entry.deadline) => {
116+
// Deadline reached — interrupt if still active.
117+
if entry.active.load(Ordering::Relaxed) {
118+
let _ = entry.conn.interrupt();
119+
}
120+
// Remove this entry.
121+
let mut entries = inner.entries.lock().unwrap();
122+
// Pop entries that are done (expired or cancelled).
123+
while let Some(Reverse(e)) = entries.heap.peek() {
124+
if !e.active.load(Ordering::Relaxed) || e.id == entry.id {
125+
entries.heap.pop();
126+
} else {
127+
break;
128+
}
129+
}
130+
}
131+
_ = inner.notify.notified() => {
132+
// A new earlier deadline was added; re-check.
133+
}
134+
}
135+
}
136+
None => {
137+
// Nothing to do — wait until a new entry is registered.
138+
inner.notify.notified().await;
139+
}
140+
}
141+
}
142+
}
143+
}
144+
145+
/// Dropping this guard cancels the associated query timeout.
146+
pub struct TimeoutGuard {
147+
active: Arc<AtomicBool>,
148+
}
149+
150+
impl Drop for TimeoutGuard {
151+
fn drop(&mut self) {
152+
self.active.store(false, Ordering::Relaxed);
153+
}
154+
}

0 commit comments

Comments
 (0)