diff --git a/compat.js b/compat.js index 8cac2d8..5fc1fc1 100644 --- a/compat.js +++ b/compat.js @@ -309,14 +309,7 @@ class Statement { try { const { params, queryOptions } = splitBindParameters(bindParameters); const it = statementIterateSync(this.stmt, params, queryOptions); - return { - next: () => iteratorNextSync(it), - [Symbol.iterator]() { - return { - next: () => iteratorNextSync(it), - } - }, - }; + return wrappedIter(it); } catch (err) { throw convertError(err); } @@ -331,11 +324,17 @@ class Statement { try { const result = []; const iterator = this.iterate(...bindParameters); - let next; - while (!(next = iterator.next()).done) { - result.push(next.value); + try { + let next; + while (!(next = iterator.next()).done) { + result.push(next.value); + } + return result; + } finally { + if (typeof iterator.return === "function") { + iterator.return(); + } } - return result; } catch (err) { throw convertError(err); } @@ -365,6 +364,26 @@ class Statement { } } +function wrappedIter(it) { + return { + next() { + return iteratorNextSync(it); + }, + return(value) { + if (typeof it.close === "function") { + it.close(); + } + return { + done: true, + value, + }; + }, + [Symbol.iterator]() { + return this; + }, + }; +} + module.exports = Database; module.exports.SqliteError = SqliteError; module.exports.Authorization = Authorization; diff --git a/index.d.ts b/index.d.ts index dec5287..9cdec0a 100644 --- a/index.d.ts +++ b/index.d.ts @@ -80,6 +80,32 @@ export declare class Database { * - Legacy format: `{ [tableName: string]: 0 | 1 }` * - Full format: `{ rules: AuthRule[], defaultPolicy?: 0 | 1 | 2 }` * - `null` to remove the authorizer + * + * Pattern fields (`table`, `column`, `entity`) accept a plain string for + * exact matching, or `{ glob: "pattern" }` for glob matching with `*` and `?`. + * + * # Examples + * + * ```javascript + * const { Authorization, Action } = require('libsql'); + * + * // Legacy table-level allow/deny + * db.authorizer({ "users": Authorization.ALLOW }); + * + * // Rule-based with glob patterns + * db.authorizer({ + * rules: [ + * { action: Action.READ, table: "users", column: "password", policy: Authorization.IGNORE }, + * { action: Action.INSERT, table: { glob: "logs_*" }, policy: Authorization.ALLOW }, + * { action: Action.READ, policy: Authorization.ALLOW }, + * { action: Action.SELECT, policy: Authorization.ALLOW }, + * ], + * defaultPolicy: Authorization.DENY, + * }); + * + * // Remove authorizer + * db.authorizer(null); + * ``` */ authorizer(config: unknown): void /** @@ -173,6 +199,7 @@ export declare class Statement { } /** A raw iterator over rows. The JavaScript layer wraps this in a iterable. */ export declare class RowsIterator { + close(): void next(): Promise } export declare class Record { diff --git a/integration-tests/tests/async.test.js b/integration-tests/tests/async.test.js index cf5a205..1d54f71 100644 --- a/integration-tests/tests/async.test.js +++ b/integration-tests/tests/async.test.js @@ -423,6 +423,30 @@ test.serial("Query timeout option allows short-running query", async (t) => { db.close(); }); +test.serial("Stale timeout guard from exhausted iterator does not interrupt later queries", async (t) => { + t.timeout(30_000); + const [db] = await connect(":memory:", { defaultQueryTimeout: 1_000 }); + + // Insert test data. + await db.exec("CREATE TABLE t(x INTEGER)"); + const insert = await db.prepare("INSERT INTO t VALUES (?)"); + for (let i = 0; i < 2_000; i++) { + await insert.run(i); + } + + // Run many sequential queries via stmt.all() (which uses iterate() internally). + // Each query finishes well under the timeout, but if the RowsIterator's + // TimeoutGuard is not released until GC, stale guards will fire and + // interrupt unrelated later queries. + const stmt = await db.prepare("SELECT * FROM t ORDER BY x ASC"); + for (let i = 0; i < 150; i++) { + const rows = await stmt.all(); + t.is(rows.length, 2_000); + } + + db.close(); +}); + test.serial("Per-query timeout option interrupts long-running Statement.all()", async (t) => { const [db, errorType] = await connect(":memory:"); const stmt = await db.prepare( diff --git a/promise.js b/promise.js index 50214e6..95ca75b 100644 --- a/promise.js +++ b/promise.js @@ -354,18 +354,7 @@ class Statement { try { const { params, queryOptions } = splitBindParameters(bindParameters); const it = await this.stmt.iterate(params, queryOptions); - return { - next() { - return it.next(); - }, - [Symbol.asyncIterator]() { - return { - next() { - return it.next(); - } - }; - } - }; + return wrappedIter(it); } catch (err) { throw convertError(err); } @@ -380,11 +369,17 @@ class Statement { try { const result = []; const iterator = await this.iterate(...bindParameters); - let next; - while (!(next = await iterator.next()).done) { - result.push(next.value); + try { + let next; + while (!(next = await iterator.next()).done) { + result.push(next.value); + } + return result; + } finally { + if (typeof iterator.return === "function") { + await iterator.return(); + } } - return result; } catch (err) { throw convertError(err); } @@ -414,6 +409,26 @@ class Statement { } } +function wrappedIter(it) { + return { + next() { + return it.next(); + }, + return(value) { + if (typeof it.close === "function") { + it.close(); + } + return { + done: true, + value, + }; + }, + [Symbol.asyncIterator]() { + return this; + } + }; +} + module.exports = { Action, Authorization, diff --git a/src/lib.rs b/src/lib.rs index 21f4389..834fdc0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,7 +34,7 @@ use std::{ str::FromStr, sync::{ atomic::{AtomicBool, Ordering}, - Arc, + Arc, Mutex, }, time::Duration, }; @@ -1391,7 +1391,7 @@ pub struct RowsIterator { safe_ints: bool, raw: bool, pluck: bool, - _timeout_guard: Option, + timeout_guard: Mutex>, } #[napi] @@ -1410,14 +1410,23 @@ impl RowsIterator { safe_ints, raw, pluck, - _timeout_guard: timeout_guard, + timeout_guard: Mutex::new(timeout_guard), } } #[napi] pub async fn next(&self) -> Result { let mut rows = self.rows.lock().await; - let row = rows.next().await.map_err(Error::from)?; + let row = match rows.next().await { + Ok(row) => row, + Err(err) => { + self.release_timeout_guard(); + return Err(Error::from(err).into()); + } + }; + if row.is_none() { + self.release_timeout_guard(); + } Ok(Record { row, column_names: self.column_names.clone(), @@ -1426,6 +1435,16 @@ impl RowsIterator { pluck: self.pluck, }) } + + #[napi] + pub fn close(&self) { + self.release_timeout_guard(); + } + + fn release_timeout_guard(&self) { + let mut guard = self.timeout_guard.lock().unwrap(); + guard.take(); + } } /// Retrieve next row from an iterator synchronously. Needed for better-sqlite3 API compatibility.