diff --git a/compat.js b/compat.js index 374a741..a4a5363 100644 --- a/compat.js +++ b/compat.js @@ -48,6 +48,9 @@ function splitBindParameters(bindParameters) { return { params: bindParameters.length === 1 ? bindParameters[0] : bindParameters, queryOptions: undefined }; } +const symbolDispose = typeof Symbol.dispose === "symbol" ? Symbol.dispose : null; +const hasWeakRef = typeof WeakRef === "function"; + /** * Database represents a connection that can prepare and execute SQL statements. */ @@ -61,6 +64,8 @@ class Database { constructor(path, opts) { this.db = new NativeDb(path, opts); this.memory = this.db.memory + this._closed = false; + this._statements = hasWeakRef ? new Set() : null; const db = this.db; Object.defineProperties(this, { inTransaction: { @@ -95,7 +100,13 @@ class Database { prepare(sql) { try { const stmt = databasePrepareSync(this.db, sql); - return new Statement(stmt); + const wrappedStmt = new Statement(stmt, this); + if (this._statements != null) { + const statementRef = new WeakRef(wrappedStmt); + wrappedStmt._statementRef = statementRef; + this._statements.add(statementRef); + } + return wrappedStmt; } catch (err) { throw convertError(err); } @@ -215,6 +226,21 @@ class Database { * Closes the database connection. */ close() { + if (this._closed) { + return; + } + this._closed = true; + if (this._statements != null) { + for (const statementRef of Array.from(this._statements)) { + const statement = statementRef.deref(); + if (statement == null) { + this._statements.delete(statementRef); + continue; + } + statement.close(); + } + this._statements.clear(); + } this.db.close(); } @@ -240,8 +266,36 @@ class Database { * Statement represents a prepared SQL statement that can be executed. */ class Statement { - constructor(stmt) { + constructor(stmt, database) { this.stmt = stmt; + this.database = database; + this._statementRef = null; + this._closed = false; + } + + close() { + if (this._closed) { + return this; + } + this._closed = true; + if (this.database != null && this.database._statements != null && this._statementRef != null) { + this.database._statements.delete(this._statementRef); + } + if (this.database != null) { + this.database = null; + } + if (this.stmt != null) { + this.stmt.close(); + this.stmt = null; + } + return this; + } + + _getNativeStatement() { + if (this._closed || this.stmt == null) { + throw new TypeError("The database connection is not open"); + } + return this.stmt; } /** @@ -250,7 +304,7 @@ class Statement { * @param raw Enable or disable raw mode. If you don't pass the parameter, raw mode is enabled. */ raw(raw) { - this.stmt.raw(raw); + this._getNativeStatement().raw(raw); return this; } @@ -260,7 +314,7 @@ class Statement { * @param pluckMode Enable or disable pluck mode. If you don't pass the parameter, pluck mode is enabled. */ pluck(pluckMode) { - this.stmt.pluck(pluckMode); + this._getNativeStatement().pluck(pluckMode); return this; } @@ -270,12 +324,12 @@ class Statement { * @param timing Enable or disable query timing. If you don't pass the parameter, query timing is enabled. */ timing(timingMode) { - this.stmt.timing(timingMode); + this._getNativeStatement().timing(timingMode); return this; } get reader() { - return this.stmt.columns().length > 0; + return this._getNativeStatement().columns().length > 0; } /** @@ -284,7 +338,7 @@ class Statement { run(...bindParameters) { try { const { params, queryOptions } = splitBindParameters(bindParameters); - return statementRunSync(this.stmt, params, queryOptions); + return statementRunSync(this._getNativeStatement(), params, queryOptions); } catch (err) { throw convertError(err); } @@ -298,7 +352,7 @@ class Statement { get(...bindParameters) { try { const { params, queryOptions } = splitBindParameters(bindParameters); - return statementGetSync(this.stmt, params, queryOptions); + return statementGetSync(this._getNativeStatement(), params, queryOptions); } catch (err) { throw convertError(err); } @@ -312,7 +366,7 @@ class Statement { iterate(...bindParameters) { try { const { params, queryOptions } = splitBindParameters(bindParameters); - const it = statementIterateSync(this.stmt, params, queryOptions); + const it = statementIterateSync(this._getNativeStatement(), params, queryOptions); return { next: () => iteratorNextSync(it), [Symbol.iterator]() { @@ -349,7 +403,7 @@ class Statement { * Interrupts the statement. */ interrupt() { - this.stmt.interrupt(); + this._getNativeStatement().interrupt(); return this; } @@ -357,18 +411,23 @@ class Statement { * Returns the columns in the result set returned by this prepared statement. */ columns() { - return this.stmt.columns(); + return this._getNativeStatement().columns(); } /** * Toggle 64-bit integer support. */ safeIntegers(toggle) { - this.stmt.safeIntegers(toggle); + this._getNativeStatement().safeIntegers(toggle); return this; } } +if (symbolDispose != null) { + Database.prototype[symbolDispose] = Database.prototype.close; + Statement.prototype[symbolDispose] = Statement.prototype.close; +} + module.exports = Database; module.exports.SqliteError = SqliteError; module.exports.Authorization = Authorization; diff --git a/docs/api.md b/docs/api.md index 46116cb..c08acda 100644 --- a/docs/api.md +++ b/docs/api.md @@ -116,6 +116,7 @@ Cancel ongoing operations and make them return at earliest opportunity. ### close() ⇒ this Closes the database connection. +All statements created from this database are closed as well. # class Statement @@ -159,6 +160,10 @@ Executes the SQL statement and returns an iterator to the resulting rows. | bindParameters | array of objects | The bind parameters for executing the statement. | | queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). | +### close() ⇒ this + +Closes the prepared statement and releases its resources. + ### pluck([toggleState]) ⇒ this This function is currently not supported. diff --git a/index.d.ts b/index.d.ts index 367e13e..0320c06 100644 --- a/index.d.ts +++ b/index.d.ts @@ -183,6 +183,8 @@ export declare class Statement { columns(): unknown[] safeIntegers(toggle?: boolean | undefined | null): this interrupt(): void + /** Closes the statement. */ + close(): void } /** A raw iterator over rows. The JavaScript layer wraps this in a iterable. */ export declare class RowsIterator { diff --git a/integration-tests/tests/async.test.js b/integration-tests/tests/async.test.js index cf5a205..01e4215 100644 --- a/integration-tests/tests/async.test.js +++ b/integration-tests/tests/async.test.js @@ -340,6 +340,30 @@ test.serial("Database.exec() after close()", async (t) => { }); }); +test.serial("Statement.get() after Database.close()", async (t) => { + const db = t.context.db; + const stmt = await db.prepare("SELECT 1"); + db.close(); + await t.throwsAsync(async () => { + await stmt.get(); + }, { + instanceOf: TypeError, + message: "The database connection is not open" + }); +}); + +test.serial("Statement.close()", async (t) => { + const db = t.context.db; + const stmt = await db.prepare("SELECT 1"); + stmt.close(); + await t.throwsAsync(async () => { + await stmt.get(); + }, { + instanceOf: TypeError, + message: "The database connection is not open" + }); +}); + test.serial("Database.interrupt()", async (t) => { const db = t.context.db; const stmt = await db.prepare("WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;"); diff --git a/integration-tests/tests/sync.test.js b/integration-tests/tests/sync.test.js index 8bbf18e..0dfe37e 100644 --- a/integration-tests/tests/sync.test.js +++ b/integration-tests/tests/sync.test.js @@ -433,6 +433,34 @@ test.serial("Database.exec() after close()", async (t) => { }); }); +test.serial("Statement.get() after Database.close()", async (t) => { + const db = t.context.db; + const stmt = db.prepare("SELECT 1"); + db.close(); + t.throws(() => { + stmt.get(); + }, { + instanceOf: TypeError, + message: "The database connection is not open" + }); +}); + +test.serial("Statement.close()", async (t) => { + const db = t.context.db; + const stmt = db.prepare("SELECT 1"); + if (typeof stmt.close !== "function") { + t.pass(); + return; + } + stmt.close(); + t.throws(() => { + stmt.get(); + }, { + instanceOf: TypeError, + message: "The database connection is not open" + }); +}); + test.serial("Timeout option", async (t) => { const timeout = 1000; const path = genDatabaseFilename(); diff --git a/promise.js b/promise.js index f364013..d1013b1 100644 --- a/promise.js +++ b/promise.js @@ -61,6 +61,9 @@ function splitBindParameters(bindParameters) { return { params: bindParameters.length === 1 ? bindParameters[0] : bindParameters, queryOptions: undefined }; } +const symbolDispose = typeof Symbol.dispose === "symbol" ? Symbol.dispose : null; +const hasWeakRef = typeof WeakRef === "function"; + /** * Creates a new database connection. * @@ -85,6 +88,8 @@ class Database { constructor(db) { this.db = db; this.memory = this.db.memory + this._closed = false; + this._statements = hasWeakRef ? new Set() : null; /** @type boolean */ this.inTransaction; @@ -118,7 +123,13 @@ class Database { async prepare(sql) { try { const stmt = await this.db.prepare(sql); - return new Statement(stmt); + const wrappedStmt = new Statement(stmt, this); + if (this._statements != null) { + const statementRef = new WeakRef(wrappedStmt); + wrappedStmt._statementRef = statementRef; + this._statements.add(statementRef); + } + return wrappedStmt; } catch (err) { throw convertError(err); } @@ -256,6 +267,21 @@ class Database { * Closes the database connection. */ close() { + if (this._closed) { + return; + } + this._closed = true; + if (this._statements != null) { + for (const statementRef of Array.from(this._statements)) { + const statement = statementRef.deref(); + if (statement == null) { + this._statements.delete(statementRef); + continue; + } + statement.close(); + } + this._statements.clear(); + } this.db.close(); } @@ -285,8 +311,36 @@ class Statement { /** * @param {NativeStatement} stmt */ - constructor(stmt) { + constructor(stmt, database) { this.stmt = stmt; + this.database = database; + this._statementRef = null; + this._closed = false; + } + + close() { + if (this._closed) { + return this; + } + this._closed = true; + if (this.database != null && this.database._statements != null && this._statementRef != null) { + this.database._statements.delete(this._statementRef); + } + if (this.database != null) { + this.database = null; + } + if (this.stmt != null) { + this.stmt.close(); + this.stmt = null; + } + return this; + } + + _getNativeStatement() { + if (this._closed || this.stmt == null) { + throw new TypeError("The database connection is not open"); + } + return this.stmt; } /** @@ -295,7 +349,7 @@ class Statement { * @param {boolean} [raw] - Enable or disable raw mode. If you don't pass the parameter, raw mode is enabled. */ raw(raw) { - this.stmt.raw(raw); + this._getNativeStatement().raw(raw); return this; } @@ -305,7 +359,7 @@ class Statement { * @param {boolean} [pluckMode] - Enable or disable pluck mode. If you don't pass the parameter, pluck mode is enabled. */ pluck(pluckMode) { - this.stmt.pluck(pluckMode); + this._getNativeStatement().pluck(pluckMode); return this; } @@ -315,12 +369,12 @@ class Statement { * @param {boolean} [timingMode] - Enable or disable query timing. If you don't pass the parameter, query timing is enabled. */ timing(timingMode) { - this.stmt.timing(timingMode); + this._getNativeStatement().timing(timingMode); return this; } get reader() { - return this.stmt.columns().length > 0; + return this._getNativeStatement().columns().length > 0; } /** @@ -329,7 +383,7 @@ class Statement { async run(...bindParameters) { try { const { params, queryOptions } = splitBindParameters(bindParameters); - return await this.stmt.run(params, queryOptions); + return await this._getNativeStatement().run(params, queryOptions); } catch (err) { throw convertError(err); } @@ -343,7 +397,7 @@ class Statement { async get(...bindParameters) { try { const { params, queryOptions } = splitBindParameters(bindParameters); - return await this.stmt.get(params, queryOptions); + return await this._getNativeStatement().get(params, queryOptions); } catch (err) { throw convertError(err); } @@ -357,7 +411,7 @@ class Statement { async iterate(...bindParameters) { try { const { params, queryOptions } = splitBindParameters(bindParameters); - const it = await this.stmt.iterate(params, queryOptions); + const it = await this._getNativeStatement().iterate(params, queryOptions); return { next() { return it.next(); @@ -398,7 +452,7 @@ class Statement { * Interrupts the statement. */ interrupt() { - this.stmt.interrupt(); + this._getNativeStatement().interrupt(); return this; } @@ -406,18 +460,23 @@ class Statement { * Returns the columns in the result set returned by this prepared statement. */ columns() { - return this.stmt.columns(); + return this._getNativeStatement().columns(); } /** * Toggle 64-bit integer support. */ safeIntegers(toggle) { - this.stmt.safeIntegers(toggle); + this._getNativeStatement().safeIntegers(toggle); return this; } } +if (symbolDispose != null) { + Database.prototype[symbolDispose] = Database.prototype.close; + Statement.prototype[symbolDispose] = Statement.prototype.close; +} + module.exports = { Authorization, Database, diff --git a/src/lib.rs b/src/lib.rs index 61e1581..3af0d22 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -664,6 +664,14 @@ fn throw_database_closed_error(env: &Env) -> napi::Error { err } +fn throw_not_open_error() -> napi::Error { + throw_sqlite_error( + "The database connection is not open".to_string(), + "SQLITE_NOTOPEN".to_string(), + 0, + ) +} + fn query_timeout_duration(timeout_ms: f64) -> Option { if timeout_ms.is_finite() && timeout_ms > 0.0 { Some(Duration::from_millis(timeout_ms as u64)) @@ -676,9 +684,9 @@ fn query_timeout_duration(timeout_ms: f64) -> Option { #[napi] pub struct Statement { // The libSQL connection instance. - conn: Arc, + conn: Option>, // The libSQL statement instance. - stmt: Arc, + stmt: Option>, // The column names. column_names: Vec, // The access mode. @@ -712,8 +720,8 @@ impl Statement { .collect(); let stmt = Arc::new(stmt); Self { - conn, - stmt, + conn: Some(conn), + stmt: Some(stmt), column_names, mode, query_timeout, @@ -733,13 +741,13 @@ impl Statement { params: Option, query_options: Option, ) -> Result { - self.stmt.reset(); - let params = map_params(&self.stmt, params)?; - let total_changes_before = self.conn.total_changes(); + let stmt = self.statement()?; + let conn = self.connection()?; + stmt.reset(); + let params = map_params(stmt.as_ref(), params)?; + let total_changes_before = conn.total_changes(); let start = std::time::Instant::now(); - let stmt = self.stmt.clone(); - let conn = self.conn.clone(); - let guard = self.start_timeout_guard(query_options); + let guard = self.start_timeout_guard(&conn, query_options); let future = async move { let _guard = guard; @@ -774,13 +782,14 @@ impl Statement { params: Option, query_options: Option, ) -> Result { + let stmt = self.statement()?; + let conn = self.connection()?; let safe_ints = self.mode.safe_ints.load(Ordering::SeqCst); let raw = self.mode.raw.load(Ordering::SeqCst); let pluck = self.mode.pluck.load(Ordering::SeqCst); let timed = self.mode.timing.load(Ordering::SeqCst); - let params = map_params(&self.stmt, params)?; - let stmt = self.stmt.clone(); + let params = map_params(stmt.as_ref(), params)?; let column_names = self.column_names.clone(); let start = if timed { @@ -790,7 +799,7 @@ impl Statement { }; let stmt_fut = stmt.clone(); - let guard = self.start_timeout_guard(query_options); + let guard = self.start_timeout_guard(&conn, query_options); let future = async move { let _guard = guard; let mut rows = stmt_fut.query(params).await.map_err(Error::from)?; @@ -854,14 +863,14 @@ impl Statement { params: Option, query_options: Option, ) -> Result { + let stmt = self.statement()?; + let conn = self.connection()?; let safe_ints = self.mode.safe_ints.load(Ordering::SeqCst); let raw = self.mode.raw.load(Ordering::SeqCst); let pluck = self.mode.pluck.load(Ordering::SeqCst); - let stmt = self.stmt.clone(); stmt.reset(); - let params = map_params(&stmt, params).unwrap(); - let stmt = self.stmt.clone(); - let guard = self.start_timeout_guard(query_options); + let params = map_params(stmt.as_ref(), params)?; + let guard = self.start_timeout_guard(&conn, query_options); let future = async move { let rows = stmt.query(params).await.map_err(Error::from)?; Ok::<_, napi::Error>(rows) @@ -881,7 +890,8 @@ impl Statement { #[napi] pub fn raw(&self, raw: Option) -> Result<&Self> { - let returns_data = !self.stmt.columns().is_empty(); + let stmt = self.statement()?; + let returns_data = !stmt.columns().is_empty(); if !returns_data { return Err(napi::Error::from_reason( "The raw() method is only for statements that return data", @@ -909,7 +919,8 @@ impl Statement { #[napi] pub fn columns(&self, env: Env) -> Result { - let columns = self.stmt.columns(); + let stmt = self.statement()?; + let columns = stmt.columns(); let mut js_array = env.create_array(columns.len() as u32)?; for (i, col) in columns.iter().enumerate() { let mut js_obj = env.create_object()?; @@ -953,12 +964,29 @@ impl Statement { #[napi] pub fn interrupt(&self) -> Result<()> { - self.stmt.interrupt().map_err(Error::from)?; + let stmt = self.statement()?; + stmt.interrupt().map_err(Error::from)?; + Ok(()) + } + + /// Closes the statement. + #[napi] + pub fn close(&mut self) -> Result<()> { + self.stmt = None; + self.conn = None; Ok(()) } } impl Statement { + fn statement(&self) -> Result> { + self.stmt.clone().ok_or_else(throw_not_open_error) + } + + fn connection(&self) -> Result> { + self.conn.clone().ok_or_else(throw_not_open_error) + } + fn resolve_query_timeout(&self, query_options: Option) -> Option { match query_options.and_then(|o| o.queryTimeout) { Some(timeout_ms) => query_timeout_duration(timeout_ms), @@ -966,9 +994,13 @@ impl Statement { } } - fn start_timeout_guard(&self, query_options: Option) -> Option { + fn start_timeout_guard( + &self, + conn: &Arc, + query_options: Option, + ) -> Option { self.resolve_query_timeout(query_options) - .map(|t| self.timeout_manager.register(&self.conn, t)) + .map(|t| self.timeout_manager.register(conn, t)) } } @@ -980,6 +1012,8 @@ pub fn statement_get_sync( params: Option, query_options: Option, ) -> Result { + let inner_stmt = stmt.statement()?; + let conn = stmt.connection()?; let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst); let raw = stmt.mode.raw.load(Ordering::SeqCst); let pluck = stmt.mode.pluck.load(Ordering::SeqCst); @@ -992,10 +1026,10 @@ pub fn statement_get_sync( }; let rt = runtime()?; - let _guard = stmt.start_timeout_guard(query_options); + let _guard = stmt.start_timeout_guard(&conn, query_options); rt.block_on(async move { - let params = map_params(&stmt.stmt, params)?; - let mut rows = stmt.stmt.query(params).await.map_err(Error::from)?; + let params = map_params(inner_stmt.as_ref(), params)?; + let mut rows = inner_stmt.query(params).await.map_err(Error::from)?; let row = rows.next().await.map_err(Error::from)?; let duration: Option = start.map(|start| start.elapsed().as_secs_f64()); let result = Statement::get_internal( @@ -1007,7 +1041,7 @@ pub fn statement_get_sync( pluck, duration, ); - stmt.stmt.reset(); + inner_stmt.reset(); result }) } @@ -1019,21 +1053,23 @@ pub fn statement_run_sync( params: Option, query_options: Option, ) -> Result { - stmt.stmt.reset(); + let inner_stmt = stmt.statement()?; + let conn = stmt.connection()?; + inner_stmt.reset(); let rt = runtime()?; - let _guard = stmt.start_timeout_guard(query_options); + let _guard = stmt.start_timeout_guard(&conn, query_options); rt.block_on(async move { - let params = map_params(&stmt.stmt, params)?; - let total_changes_before = stmt.conn.total_changes(); + let params = map_params(inner_stmt.as_ref(), params)?; + let total_changes_before = conn.total_changes(); let start = std::time::Instant::now(); - stmt.stmt.run(params).await.map_err(Error::from)?; - let changes = if stmt.conn.total_changes() == total_changes_before { + inner_stmt.run(params).await.map_err(Error::from)?; + let changes = if conn.total_changes() == total_changes_before { 0 } else { - stmt.conn.changes() + conn.changes() }; - let last_insert_row_id = stmt.conn.last_insert_rowid(); + let last_insert_row_id = conn.last_insert_rowid(); let duration = start.elapsed().as_secs_f64(); Ok(RunResult { changes: changes as f64, @@ -1051,14 +1087,15 @@ pub fn statement_iterate_sync( query_options: Option, ) -> Result { let rt = runtime()?; + let conn = stmt.connection()?; let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst); let raw = stmt.mode.raw.load(Ordering::SeqCst); let pluck = stmt.mode.pluck.load(Ordering::SeqCst); - let guard = stmt.start_timeout_guard(query_options); - let inner_stmt = stmt.stmt.clone(); + let guard = stmt.start_timeout_guard(&conn, query_options); + let inner_stmt = stmt.statement()?; let (rows, column_names) = rt.block_on(async move { inner_stmt.reset(); - let params = map_params(&inner_stmt, params)?; + let params = map_params(inner_stmt.as_ref(), params)?; let rows = inner_stmt.query(params).await.map_err(Error::from)?; let mut column_names = Vec::new(); for i in 0..rows.column_count() {