From e4a845e29bec9541f13a5bda811b95a37ee68514 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Tue, 31 Mar 2026 18:53:58 +0300 Subject: [PATCH] Add statement lifecycle management and explicit close support Track prepared statements via WeakRefs so Database.close() automatically closes all associated statements. Add explicit Statement.close() for manual resource management and Symbol.dispose support for the `using` syntax. Use-after-close now throws a clear TypeError instead of potentially undefined behavior. --- compat.js | 83 ++++++++++++++++--- docs/api.md | 5 ++ index.d.ts | 2 + integration-tests/tests/async.test.js | 24 ++++++ integration-tests/tests/sync.test.js | 28 +++++++ promise.js | 83 ++++++++++++++++--- src/lib.rs | 111 +++++++++++++++++--------- 7 files changed, 275 insertions(+), 61 deletions(-) diff --git a/compat.js b/compat.js index 374a741..a4a5363 100644 --- a/compat.js +++ b/compat.js @@ -48,6 +48,9 @@ function splitBindParameters(bindParameters) { return { params: bindParameters.length === 1 ? bindParameters[0] : bindParameters, queryOptions: undefined }; } +const symbolDispose = typeof Symbol.dispose === "symbol" ? Symbol.dispose : null; +const hasWeakRef = typeof WeakRef === "function"; + /** * Database represents a connection that can prepare and execute SQL statements. */ @@ -61,6 +64,8 @@ class Database { constructor(path, opts) { this.db = new NativeDb(path, opts); this.memory = this.db.memory + this._closed = false; + this._statements = hasWeakRef ? new Set() : null; const db = this.db; Object.defineProperties(this, { inTransaction: { @@ -95,7 +100,13 @@ class Database { prepare(sql) { try { const stmt = databasePrepareSync(this.db, sql); - return new Statement(stmt); + const wrappedStmt = new Statement(stmt, this); + if (this._statements != null) { + const statementRef = new WeakRef(wrappedStmt); + wrappedStmt._statementRef = statementRef; + this._statements.add(statementRef); + } + return wrappedStmt; } catch (err) { throw convertError(err); } @@ -215,6 +226,21 @@ class Database { * Closes the database connection. */ close() { + if (this._closed) { + return; + } + this._closed = true; + if (this._statements != null) { + for (const statementRef of Array.from(this._statements)) { + const statement = statementRef.deref(); + if (statement == null) { + this._statements.delete(statementRef); + continue; + } + statement.close(); + } + this._statements.clear(); + } this.db.close(); } @@ -240,8 +266,36 @@ class Database { * Statement represents a prepared SQL statement that can be executed. */ class Statement { - constructor(stmt) { + constructor(stmt, database) { this.stmt = stmt; + this.database = database; + this._statementRef = null; + this._closed = false; + } + + close() { + if (this._closed) { + return this; + } + this._closed = true; + if (this.database != null && this.database._statements != null && this._statementRef != null) { + this.database._statements.delete(this._statementRef); + } + if (this.database != null) { + this.database = null; + } + if (this.stmt != null) { + this.stmt.close(); + this.stmt = null; + } + return this; + } + + _getNativeStatement() { + if (this._closed || this.stmt == null) { + throw new TypeError("The database connection is not open"); + } + return this.stmt; } /** @@ -250,7 +304,7 @@ class Statement { * @param raw Enable or disable raw mode. If you don't pass the parameter, raw mode is enabled. */ raw(raw) { - this.stmt.raw(raw); + this._getNativeStatement().raw(raw); return this; } @@ -260,7 +314,7 @@ class Statement { * @param pluckMode Enable or disable pluck mode. If you don't pass the parameter, pluck mode is enabled. */ pluck(pluckMode) { - this.stmt.pluck(pluckMode); + this._getNativeStatement().pluck(pluckMode); return this; } @@ -270,12 +324,12 @@ class Statement { * @param timing Enable or disable query timing. If you don't pass the parameter, query timing is enabled. */ timing(timingMode) { - this.stmt.timing(timingMode); + this._getNativeStatement().timing(timingMode); return this; } get reader() { - return this.stmt.columns().length > 0; + return this._getNativeStatement().columns().length > 0; } /** @@ -284,7 +338,7 @@ class Statement { run(...bindParameters) { try { const { params, queryOptions } = splitBindParameters(bindParameters); - return statementRunSync(this.stmt, params, queryOptions); + return statementRunSync(this._getNativeStatement(), params, queryOptions); } catch (err) { throw convertError(err); } @@ -298,7 +352,7 @@ class Statement { get(...bindParameters) { try { const { params, queryOptions } = splitBindParameters(bindParameters); - return statementGetSync(this.stmt, params, queryOptions); + return statementGetSync(this._getNativeStatement(), params, queryOptions); } catch (err) { throw convertError(err); } @@ -312,7 +366,7 @@ class Statement { iterate(...bindParameters) { try { const { params, queryOptions } = splitBindParameters(bindParameters); - const it = statementIterateSync(this.stmt, params, queryOptions); + const it = statementIterateSync(this._getNativeStatement(), params, queryOptions); return { next: () => iteratorNextSync(it), [Symbol.iterator]() { @@ -349,7 +403,7 @@ class Statement { * Interrupts the statement. */ interrupt() { - this.stmt.interrupt(); + this._getNativeStatement().interrupt(); return this; } @@ -357,18 +411,23 @@ class Statement { * Returns the columns in the result set returned by this prepared statement. */ columns() { - return this.stmt.columns(); + return this._getNativeStatement().columns(); } /** * Toggle 64-bit integer support. */ safeIntegers(toggle) { - this.stmt.safeIntegers(toggle); + this._getNativeStatement().safeIntegers(toggle); return this; } } +if (symbolDispose != null) { + Database.prototype[symbolDispose] = Database.prototype.close; + Statement.prototype[symbolDispose] = Statement.prototype.close; +} + module.exports = Database; module.exports.SqliteError = SqliteError; module.exports.Authorization = Authorization; diff --git a/docs/api.md b/docs/api.md index 46116cb..c08acda 100644 --- a/docs/api.md +++ b/docs/api.md @@ -116,6 +116,7 @@ Cancel ongoing operations and make them return at earliest opportunity. ### close() ⇒ this Closes the database connection. +All statements created from this database are closed as well. # class Statement @@ -159,6 +160,10 @@ Executes the SQL statement and returns an iterator to the resulting rows. | bindParameters | array of objects | The bind parameters for executing the statement. | | queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). | +### close() ⇒ this + +Closes the prepared statement and releases its resources. + ### pluck([toggleState]) ⇒ this This function is currently not supported. diff --git a/index.d.ts b/index.d.ts index 367e13e..0320c06 100644 --- a/index.d.ts +++ b/index.d.ts @@ -183,6 +183,8 @@ export declare class Statement { columns(): unknown[] safeIntegers(toggle?: boolean | undefined | null): this interrupt(): void + /** Closes the statement. */ + close(): void } /** A raw iterator over rows. The JavaScript layer wraps this in a iterable. */ export declare class RowsIterator { diff --git a/integration-tests/tests/async.test.js b/integration-tests/tests/async.test.js index cf5a205..01e4215 100644 --- a/integration-tests/tests/async.test.js +++ b/integration-tests/tests/async.test.js @@ -340,6 +340,30 @@ test.serial("Database.exec() after close()", async (t) => { }); }); +test.serial("Statement.get() after Database.close()", async (t) => { + const db = t.context.db; + const stmt = await db.prepare("SELECT 1"); + db.close(); + await t.throwsAsync(async () => { + await stmt.get(); + }, { + instanceOf: TypeError, + message: "The database connection is not open" + }); +}); + +test.serial("Statement.close()", async (t) => { + const db = t.context.db; + const stmt = await db.prepare("SELECT 1"); + stmt.close(); + await t.throwsAsync(async () => { + await stmt.get(); + }, { + instanceOf: TypeError, + message: "The database connection is not open" + }); +}); + test.serial("Database.interrupt()", async (t) => { const db = t.context.db; const stmt = await db.prepare("WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;"); diff --git a/integration-tests/tests/sync.test.js b/integration-tests/tests/sync.test.js index 8bbf18e..0dfe37e 100644 --- a/integration-tests/tests/sync.test.js +++ b/integration-tests/tests/sync.test.js @@ -433,6 +433,34 @@ test.serial("Database.exec() after close()", async (t) => { }); }); +test.serial("Statement.get() after Database.close()", async (t) => { + const db = t.context.db; + const stmt = db.prepare("SELECT 1"); + db.close(); + t.throws(() => { + stmt.get(); + }, { + instanceOf: TypeError, + message: "The database connection is not open" + }); +}); + +test.serial("Statement.close()", async (t) => { + const db = t.context.db; + const stmt = db.prepare("SELECT 1"); + if (typeof stmt.close !== "function") { + t.pass(); + return; + } + stmt.close(); + t.throws(() => { + stmt.get(); + }, { + instanceOf: TypeError, + message: "The database connection is not open" + }); +}); + test.serial("Timeout option", async (t) => { const timeout = 1000; const path = genDatabaseFilename(); diff --git a/promise.js b/promise.js index f364013..d1013b1 100644 --- a/promise.js +++ b/promise.js @@ -61,6 +61,9 @@ function splitBindParameters(bindParameters) { return { params: bindParameters.length === 1 ? bindParameters[0] : bindParameters, queryOptions: undefined }; } +const symbolDispose = typeof Symbol.dispose === "symbol" ? Symbol.dispose : null; +const hasWeakRef = typeof WeakRef === "function"; + /** * Creates a new database connection. * @@ -85,6 +88,8 @@ class Database { constructor(db) { this.db = db; this.memory = this.db.memory + this._closed = false; + this._statements = hasWeakRef ? new Set() : null; /** @type boolean */ this.inTransaction; @@ -118,7 +123,13 @@ class Database { async prepare(sql) { try { const stmt = await this.db.prepare(sql); - return new Statement(stmt); + const wrappedStmt = new Statement(stmt, this); + if (this._statements != null) { + const statementRef = new WeakRef(wrappedStmt); + wrappedStmt._statementRef = statementRef; + this._statements.add(statementRef); + } + return wrappedStmt; } catch (err) { throw convertError(err); } @@ -256,6 +267,21 @@ class Database { * Closes the database connection. */ close() { + if (this._closed) { + return; + } + this._closed = true; + if (this._statements != null) { + for (const statementRef of Array.from(this._statements)) { + const statement = statementRef.deref(); + if (statement == null) { + this._statements.delete(statementRef); + continue; + } + statement.close(); + } + this._statements.clear(); + } this.db.close(); } @@ -285,8 +311,36 @@ class Statement { /** * @param {NativeStatement} stmt */ - constructor(stmt) { + constructor(stmt, database) { this.stmt = stmt; + this.database = database; + this._statementRef = null; + this._closed = false; + } + + close() { + if (this._closed) { + return this; + } + this._closed = true; + if (this.database != null && this.database._statements != null && this._statementRef != null) { + this.database._statements.delete(this._statementRef); + } + if (this.database != null) { + this.database = null; + } + if (this.stmt != null) { + this.stmt.close(); + this.stmt = null; + } + return this; + } + + _getNativeStatement() { + if (this._closed || this.stmt == null) { + throw new TypeError("The database connection is not open"); + } + return this.stmt; } /** @@ -295,7 +349,7 @@ class Statement { * @param {boolean} [raw] - Enable or disable raw mode. If you don't pass the parameter, raw mode is enabled. */ raw(raw) { - this.stmt.raw(raw); + this._getNativeStatement().raw(raw); return this; } @@ -305,7 +359,7 @@ class Statement { * @param {boolean} [pluckMode] - Enable or disable pluck mode. If you don't pass the parameter, pluck mode is enabled. */ pluck(pluckMode) { - this.stmt.pluck(pluckMode); + this._getNativeStatement().pluck(pluckMode); return this; } @@ -315,12 +369,12 @@ class Statement { * @param {boolean} [timingMode] - Enable or disable query timing. If you don't pass the parameter, query timing is enabled. */ timing(timingMode) { - this.stmt.timing(timingMode); + this._getNativeStatement().timing(timingMode); return this; } get reader() { - return this.stmt.columns().length > 0; + return this._getNativeStatement().columns().length > 0; } /** @@ -329,7 +383,7 @@ class Statement { async run(...bindParameters) { try { const { params, queryOptions } = splitBindParameters(bindParameters); - return await this.stmt.run(params, queryOptions); + return await this._getNativeStatement().run(params, queryOptions); } catch (err) { throw convertError(err); } @@ -343,7 +397,7 @@ class Statement { async get(...bindParameters) { try { const { params, queryOptions } = splitBindParameters(bindParameters); - return await this.stmt.get(params, queryOptions); + return await this._getNativeStatement().get(params, queryOptions); } catch (err) { throw convertError(err); } @@ -357,7 +411,7 @@ class Statement { async iterate(...bindParameters) { try { const { params, queryOptions } = splitBindParameters(bindParameters); - const it = await this.stmt.iterate(params, queryOptions); + const it = await this._getNativeStatement().iterate(params, queryOptions); return { next() { return it.next(); @@ -398,7 +452,7 @@ class Statement { * Interrupts the statement. */ interrupt() { - this.stmt.interrupt(); + this._getNativeStatement().interrupt(); return this; } @@ -406,18 +460,23 @@ class Statement { * Returns the columns in the result set returned by this prepared statement. */ columns() { - return this.stmt.columns(); + return this._getNativeStatement().columns(); } /** * Toggle 64-bit integer support. */ safeIntegers(toggle) { - this.stmt.safeIntegers(toggle); + this._getNativeStatement().safeIntegers(toggle); return this; } } +if (symbolDispose != null) { + Database.prototype[symbolDispose] = Database.prototype.close; + Statement.prototype[symbolDispose] = Statement.prototype.close; +} + module.exports = { Authorization, Database, diff --git a/src/lib.rs b/src/lib.rs index 61e1581..3af0d22 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -664,6 +664,14 @@ fn throw_database_closed_error(env: &Env) -> napi::Error { err } +fn throw_not_open_error() -> napi::Error { + throw_sqlite_error( + "The database connection is not open".to_string(), + "SQLITE_NOTOPEN".to_string(), + 0, + ) +} + fn query_timeout_duration(timeout_ms: f64) -> Option { if timeout_ms.is_finite() && timeout_ms > 0.0 { Some(Duration::from_millis(timeout_ms as u64)) @@ -676,9 +684,9 @@ fn query_timeout_duration(timeout_ms: f64) -> Option { #[napi] pub struct Statement { // The libSQL connection instance. - conn: Arc, + conn: Option>, // The libSQL statement instance. - stmt: Arc, + stmt: Option>, // The column names. column_names: Vec, // The access mode. @@ -712,8 +720,8 @@ impl Statement { .collect(); let stmt = Arc::new(stmt); Self { - conn, - stmt, + conn: Some(conn), + stmt: Some(stmt), column_names, mode, query_timeout, @@ -733,13 +741,13 @@ impl Statement { params: Option, query_options: Option, ) -> Result { - self.stmt.reset(); - let params = map_params(&self.stmt, params)?; - let total_changes_before = self.conn.total_changes(); + let stmt = self.statement()?; + let conn = self.connection()?; + stmt.reset(); + let params = map_params(stmt.as_ref(), params)?; + let total_changes_before = conn.total_changes(); let start = std::time::Instant::now(); - let stmt = self.stmt.clone(); - let conn = self.conn.clone(); - let guard = self.start_timeout_guard(query_options); + let guard = self.start_timeout_guard(&conn, query_options); let future = async move { let _guard = guard; @@ -774,13 +782,14 @@ impl Statement { params: Option, query_options: Option, ) -> Result { + let stmt = self.statement()?; + let conn = self.connection()?; let safe_ints = self.mode.safe_ints.load(Ordering::SeqCst); let raw = self.mode.raw.load(Ordering::SeqCst); let pluck = self.mode.pluck.load(Ordering::SeqCst); let timed = self.mode.timing.load(Ordering::SeqCst); - let params = map_params(&self.stmt, params)?; - let stmt = self.stmt.clone(); + let params = map_params(stmt.as_ref(), params)?; let column_names = self.column_names.clone(); let start = if timed { @@ -790,7 +799,7 @@ impl Statement { }; let stmt_fut = stmt.clone(); - let guard = self.start_timeout_guard(query_options); + let guard = self.start_timeout_guard(&conn, query_options); let future = async move { let _guard = guard; let mut rows = stmt_fut.query(params).await.map_err(Error::from)?; @@ -854,14 +863,14 @@ impl Statement { params: Option, query_options: Option, ) -> Result { + let stmt = self.statement()?; + let conn = self.connection()?; let safe_ints = self.mode.safe_ints.load(Ordering::SeqCst); let raw = self.mode.raw.load(Ordering::SeqCst); let pluck = self.mode.pluck.load(Ordering::SeqCst); - let stmt = self.stmt.clone(); stmt.reset(); - let params = map_params(&stmt, params).unwrap(); - let stmt = self.stmt.clone(); - let guard = self.start_timeout_guard(query_options); + let params = map_params(stmt.as_ref(), params)?; + let guard = self.start_timeout_guard(&conn, query_options); let future = async move { let rows = stmt.query(params).await.map_err(Error::from)?; Ok::<_, napi::Error>(rows) @@ -881,7 +890,8 @@ impl Statement { #[napi] pub fn raw(&self, raw: Option) -> Result<&Self> { - let returns_data = !self.stmt.columns().is_empty(); + let stmt = self.statement()?; + let returns_data = !stmt.columns().is_empty(); if !returns_data { return Err(napi::Error::from_reason( "The raw() method is only for statements that return data", @@ -909,7 +919,8 @@ impl Statement { #[napi] pub fn columns(&self, env: Env) -> Result { - let columns = self.stmt.columns(); + let stmt = self.statement()?; + let columns = stmt.columns(); let mut js_array = env.create_array(columns.len() as u32)?; for (i, col) in columns.iter().enumerate() { let mut js_obj = env.create_object()?; @@ -953,12 +964,29 @@ impl Statement { #[napi] pub fn interrupt(&self) -> Result<()> { - self.stmt.interrupt().map_err(Error::from)?; + let stmt = self.statement()?; + stmt.interrupt().map_err(Error::from)?; + Ok(()) + } + + /// Closes the statement. + #[napi] + pub fn close(&mut self) -> Result<()> { + self.stmt = None; + self.conn = None; Ok(()) } } impl Statement { + fn statement(&self) -> Result> { + self.stmt.clone().ok_or_else(throw_not_open_error) + } + + fn connection(&self) -> Result> { + self.conn.clone().ok_or_else(throw_not_open_error) + } + fn resolve_query_timeout(&self, query_options: Option) -> Option { match query_options.and_then(|o| o.queryTimeout) { Some(timeout_ms) => query_timeout_duration(timeout_ms), @@ -966,9 +994,13 @@ impl Statement { } } - fn start_timeout_guard(&self, query_options: Option) -> Option { + fn start_timeout_guard( + &self, + conn: &Arc, + query_options: Option, + ) -> Option { self.resolve_query_timeout(query_options) - .map(|t| self.timeout_manager.register(&self.conn, t)) + .map(|t| self.timeout_manager.register(conn, t)) } } @@ -980,6 +1012,8 @@ pub fn statement_get_sync( params: Option, query_options: Option, ) -> Result { + let inner_stmt = stmt.statement()?; + let conn = stmt.connection()?; let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst); let raw = stmt.mode.raw.load(Ordering::SeqCst); let pluck = stmt.mode.pluck.load(Ordering::SeqCst); @@ -992,10 +1026,10 @@ pub fn statement_get_sync( }; let rt = runtime()?; - let _guard = stmt.start_timeout_guard(query_options); + let _guard = stmt.start_timeout_guard(&conn, query_options); rt.block_on(async move { - let params = map_params(&stmt.stmt, params)?; - let mut rows = stmt.stmt.query(params).await.map_err(Error::from)?; + let params = map_params(inner_stmt.as_ref(), params)?; + let mut rows = inner_stmt.query(params).await.map_err(Error::from)?; let row = rows.next().await.map_err(Error::from)?; let duration: Option = start.map(|start| start.elapsed().as_secs_f64()); let result = Statement::get_internal( @@ -1007,7 +1041,7 @@ pub fn statement_get_sync( pluck, duration, ); - stmt.stmt.reset(); + inner_stmt.reset(); result }) } @@ -1019,21 +1053,23 @@ pub fn statement_run_sync( params: Option, query_options: Option, ) -> Result { - stmt.stmt.reset(); + let inner_stmt = stmt.statement()?; + let conn = stmt.connection()?; + inner_stmt.reset(); let rt = runtime()?; - let _guard = stmt.start_timeout_guard(query_options); + let _guard = stmt.start_timeout_guard(&conn, query_options); rt.block_on(async move { - let params = map_params(&stmt.stmt, params)?; - let total_changes_before = stmt.conn.total_changes(); + let params = map_params(inner_stmt.as_ref(), params)?; + let total_changes_before = conn.total_changes(); let start = std::time::Instant::now(); - stmt.stmt.run(params).await.map_err(Error::from)?; - let changes = if stmt.conn.total_changes() == total_changes_before { + inner_stmt.run(params).await.map_err(Error::from)?; + let changes = if conn.total_changes() == total_changes_before { 0 } else { - stmt.conn.changes() + conn.changes() }; - let last_insert_row_id = stmt.conn.last_insert_rowid(); + let last_insert_row_id = conn.last_insert_rowid(); let duration = start.elapsed().as_secs_f64(); Ok(RunResult { changes: changes as f64, @@ -1051,14 +1087,15 @@ pub fn statement_iterate_sync( query_options: Option, ) -> Result { let rt = runtime()?; + let conn = stmt.connection()?; let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst); let raw = stmt.mode.raw.load(Ordering::SeqCst); let pluck = stmt.mode.pluck.load(Ordering::SeqCst); - let guard = stmt.start_timeout_guard(query_options); - let inner_stmt = stmt.stmt.clone(); + let guard = stmt.start_timeout_guard(&conn, query_options); + let inner_stmt = stmt.statement()?; let (rows, column_names) = rt.block_on(async move { inner_stmt.reset(); - let params = map_params(&inner_stmt, params)?; + let params = map_params(inner_stmt.as_ref(), params)?; let rows = inner_stmt.query(params).await.map_err(Error::from)?; let mut column_names = Vec::new(); for i in 0..rows.column_count() {