diff --git a/compat.js b/compat.js
index 374a741..a4a5363 100644
--- a/compat.js
+++ b/compat.js
@@ -48,6 +48,9 @@ function splitBindParameters(bindParameters) {
return { params: bindParameters.length === 1 ? bindParameters[0] : bindParameters, queryOptions: undefined };
}
+const symbolDispose = typeof Symbol.dispose === "symbol" ? Symbol.dispose : null;
+const hasWeakRef = typeof WeakRef === "function";
+
/**
* Database represents a connection that can prepare and execute SQL statements.
*/
@@ -61,6 +64,8 @@ class Database {
constructor(path, opts) {
this.db = new NativeDb(path, opts);
this.memory = this.db.memory
+ this._closed = false;
+ this._statements = hasWeakRef ? new Set() : null;
const db = this.db;
Object.defineProperties(this, {
inTransaction: {
@@ -95,7 +100,13 @@ class Database {
prepare(sql) {
try {
const stmt = databasePrepareSync(this.db, sql);
- return new Statement(stmt);
+ const wrappedStmt = new Statement(stmt, this);
+ if (this._statements != null) {
+ const statementRef = new WeakRef(wrappedStmt);
+ wrappedStmt._statementRef = statementRef;
+ this._statements.add(statementRef);
+ }
+ return wrappedStmt;
} catch (err) {
throw convertError(err);
}
@@ -215,6 +226,21 @@ class Database {
* Closes the database connection.
*/
close() {
+ if (this._closed) {
+ return;
+ }
+ this._closed = true;
+ if (this._statements != null) {
+ for (const statementRef of Array.from(this._statements)) {
+ const statement = statementRef.deref();
+ if (statement == null) {
+ this._statements.delete(statementRef);
+ continue;
+ }
+ statement.close();
+ }
+ this._statements.clear();
+ }
this.db.close();
}
@@ -240,8 +266,36 @@ class Database {
* Statement represents a prepared SQL statement that can be executed.
*/
class Statement {
- constructor(stmt) {
+ constructor(stmt, database) {
this.stmt = stmt;
+ this.database = database;
+ this._statementRef = null;
+ this._closed = false;
+ }
+
+ close() {
+ if (this._closed) {
+ return this;
+ }
+ this._closed = true;
+ if (this.database != null && this.database._statements != null && this._statementRef != null) {
+ this.database._statements.delete(this._statementRef);
+ }
+ if (this.database != null) {
+ this.database = null;
+ }
+ if (this.stmt != null) {
+ this.stmt.close();
+ this.stmt = null;
+ }
+ return this;
+ }
+
+ _getNativeStatement() {
+ if (this._closed || this.stmt == null) {
+ throw new TypeError("The database connection is not open");
+ }
+ return this.stmt;
}
/**
@@ -250,7 +304,7 @@ class Statement {
* @param raw Enable or disable raw mode. If you don't pass the parameter, raw mode is enabled.
*/
raw(raw) {
- this.stmt.raw(raw);
+ this._getNativeStatement().raw(raw);
return this;
}
@@ -260,7 +314,7 @@ class Statement {
* @param pluckMode Enable or disable pluck mode. If you don't pass the parameter, pluck mode is enabled.
*/
pluck(pluckMode) {
- this.stmt.pluck(pluckMode);
+ this._getNativeStatement().pluck(pluckMode);
return this;
}
@@ -270,12 +324,12 @@ class Statement {
* @param timing Enable or disable query timing. If you don't pass the parameter, query timing is enabled.
*/
timing(timingMode) {
- this.stmt.timing(timingMode);
+ this._getNativeStatement().timing(timingMode);
return this;
}
get reader() {
- return this.stmt.columns().length > 0;
+ return this._getNativeStatement().columns().length > 0;
}
/**
@@ -284,7 +338,7 @@ class Statement {
run(...bindParameters) {
try {
const { params, queryOptions } = splitBindParameters(bindParameters);
- return statementRunSync(this.stmt, params, queryOptions);
+ return statementRunSync(this._getNativeStatement(), params, queryOptions);
} catch (err) {
throw convertError(err);
}
@@ -298,7 +352,7 @@ class Statement {
get(...bindParameters) {
try {
const { params, queryOptions } = splitBindParameters(bindParameters);
- return statementGetSync(this.stmt, params, queryOptions);
+ return statementGetSync(this._getNativeStatement(), params, queryOptions);
} catch (err) {
throw convertError(err);
}
@@ -312,7 +366,7 @@ class Statement {
iterate(...bindParameters) {
try {
const { params, queryOptions } = splitBindParameters(bindParameters);
- const it = statementIterateSync(this.stmt, params, queryOptions);
+ const it = statementIterateSync(this._getNativeStatement(), params, queryOptions);
return {
next: () => iteratorNextSync(it),
[Symbol.iterator]() {
@@ -349,7 +403,7 @@ class Statement {
* Interrupts the statement.
*/
interrupt() {
- this.stmt.interrupt();
+ this._getNativeStatement().interrupt();
return this;
}
@@ -357,18 +411,23 @@ class Statement {
* Returns the columns in the result set returned by this prepared statement.
*/
columns() {
- return this.stmt.columns();
+ return this._getNativeStatement().columns();
}
/**
* Toggle 64-bit integer support.
*/
safeIntegers(toggle) {
- this.stmt.safeIntegers(toggle);
+ this._getNativeStatement().safeIntegers(toggle);
return this;
}
}
+if (symbolDispose != null) {
+ Database.prototype[symbolDispose] = Database.prototype.close;
+ Statement.prototype[symbolDispose] = Statement.prototype.close;
+}
+
module.exports = Database;
module.exports.SqliteError = SqliteError;
module.exports.Authorization = Authorization;
diff --git a/docs/api.md b/docs/api.md
index 46116cb..c08acda 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -116,6 +116,7 @@ Cancel ongoing operations and make them return at earliest opportunity.
### close() ⇒ this
Closes the database connection.
+All statements created from this database are closed as well.
# class Statement
@@ -159,6 +160,10 @@ Executes the SQL statement and returns an iterator to the resulting rows.
| bindParameters | array of objects | The bind parameters for executing the statement. |
| queryOptions | object | Optional per-query overrides (for example, `{ queryTimeout: 100 }`). |
+### close() ⇒ this
+
+Closes the prepared statement and releases its resources.
+
### pluck([toggleState]) ⇒ this
This function is currently not supported.
diff --git a/index.d.ts b/index.d.ts
index 367e13e..0320c06 100644
--- a/index.d.ts
+++ b/index.d.ts
@@ -183,6 +183,8 @@ export declare class Statement {
columns(): unknown[]
safeIntegers(toggle?: boolean | undefined | null): this
interrupt(): void
+ /** Closes the statement. */
+ close(): void
}
/** A raw iterator over rows. The JavaScript layer wraps this in a iterable. */
export declare class RowsIterator {
diff --git a/integration-tests/tests/async.test.js b/integration-tests/tests/async.test.js
index cf5a205..01e4215 100644
--- a/integration-tests/tests/async.test.js
+++ b/integration-tests/tests/async.test.js
@@ -340,6 +340,30 @@ test.serial("Database.exec() after close()", async (t) => {
});
});
+test.serial("Statement.get() after Database.close()", async (t) => {
+ const db = t.context.db;
+ const stmt = await db.prepare("SELECT 1");
+ db.close();
+ await t.throwsAsync(async () => {
+ await stmt.get();
+ }, {
+ instanceOf: TypeError,
+ message: "The database connection is not open"
+ });
+});
+
+test.serial("Statement.close()", async (t) => {
+ const db = t.context.db;
+ const stmt = await db.prepare("SELECT 1");
+ stmt.close();
+ await t.throwsAsync(async () => {
+ await stmt.get();
+ }, {
+ instanceOf: TypeError,
+ message: "The database connection is not open"
+ });
+});
+
test.serial("Database.interrupt()", async (t) => {
const db = t.context.db;
const stmt = await db.prepare("WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;");
diff --git a/integration-tests/tests/sync.test.js b/integration-tests/tests/sync.test.js
index 8bbf18e..0dfe37e 100644
--- a/integration-tests/tests/sync.test.js
+++ b/integration-tests/tests/sync.test.js
@@ -433,6 +433,34 @@ test.serial("Database.exec() after close()", async (t) => {
});
});
+test.serial("Statement.get() after Database.close()", async (t) => {
+ const db = t.context.db;
+ const stmt = db.prepare("SELECT 1");
+ db.close();
+ t.throws(() => {
+ stmt.get();
+ }, {
+ instanceOf: TypeError,
+ message: "The database connection is not open"
+ });
+});
+
+test.serial("Statement.close()", async (t) => {
+ const db = t.context.db;
+ const stmt = db.prepare("SELECT 1");
+ if (typeof stmt.close !== "function") {
+ t.pass();
+ return;
+ }
+ stmt.close();
+ t.throws(() => {
+ stmt.get();
+ }, {
+ instanceOf: TypeError,
+ message: "The database connection is not open"
+ });
+});
+
test.serial("Timeout option", async (t) => {
const timeout = 1000;
const path = genDatabaseFilename();
diff --git a/promise.js b/promise.js
index f364013..d1013b1 100644
--- a/promise.js
+++ b/promise.js
@@ -61,6 +61,9 @@ function splitBindParameters(bindParameters) {
return { params: bindParameters.length === 1 ? bindParameters[0] : bindParameters, queryOptions: undefined };
}
+const symbolDispose = typeof Symbol.dispose === "symbol" ? Symbol.dispose : null;
+const hasWeakRef = typeof WeakRef === "function";
+
/**
* Creates a new database connection.
*
@@ -85,6 +88,8 @@ class Database {
constructor(db) {
this.db = db;
this.memory = this.db.memory
+ this._closed = false;
+ this._statements = hasWeakRef ? new Set() : null;
/** @type boolean */
this.inTransaction;
@@ -118,7 +123,13 @@ class Database {
async prepare(sql) {
try {
const stmt = await this.db.prepare(sql);
- return new Statement(stmt);
+ const wrappedStmt = new Statement(stmt, this);
+ if (this._statements != null) {
+ const statementRef = new WeakRef(wrappedStmt);
+ wrappedStmt._statementRef = statementRef;
+ this._statements.add(statementRef);
+ }
+ return wrappedStmt;
} catch (err) {
throw convertError(err);
}
@@ -256,6 +267,21 @@ class Database {
* Closes the database connection.
*/
close() {
+ if (this._closed) {
+ return;
+ }
+ this._closed = true;
+ if (this._statements != null) {
+ for (const statementRef of Array.from(this._statements)) {
+ const statement = statementRef.deref();
+ if (statement == null) {
+ this._statements.delete(statementRef);
+ continue;
+ }
+ statement.close();
+ }
+ this._statements.clear();
+ }
this.db.close();
}
@@ -285,8 +311,36 @@ class Statement {
/**
* @param {NativeStatement} stmt
*/
- constructor(stmt) {
+ constructor(stmt, database) {
this.stmt = stmt;
+ this.database = database;
+ this._statementRef = null;
+ this._closed = false;
+ }
+
+ close() {
+ if (this._closed) {
+ return this;
+ }
+ this._closed = true;
+ if (this.database != null && this.database._statements != null && this._statementRef != null) {
+ this.database._statements.delete(this._statementRef);
+ }
+ if (this.database != null) {
+ this.database = null;
+ }
+ if (this.stmt != null) {
+ this.stmt.close();
+ this.stmt = null;
+ }
+ return this;
+ }
+
+ _getNativeStatement() {
+ if (this._closed || this.stmt == null) {
+ throw new TypeError("The database connection is not open");
+ }
+ return this.stmt;
}
/**
@@ -295,7 +349,7 @@ class Statement {
* @param {boolean} [raw] - Enable or disable raw mode. If you don't pass the parameter, raw mode is enabled.
*/
raw(raw) {
- this.stmt.raw(raw);
+ this._getNativeStatement().raw(raw);
return this;
}
@@ -305,7 +359,7 @@ class Statement {
* @param {boolean} [pluckMode] - Enable or disable pluck mode. If you don't pass the parameter, pluck mode is enabled.
*/
pluck(pluckMode) {
- this.stmt.pluck(pluckMode);
+ this._getNativeStatement().pluck(pluckMode);
return this;
}
@@ -315,12 +369,12 @@ class Statement {
* @param {boolean} [timingMode] - Enable or disable query timing. If you don't pass the parameter, query timing is enabled.
*/
timing(timingMode) {
- this.stmt.timing(timingMode);
+ this._getNativeStatement().timing(timingMode);
return this;
}
get reader() {
- return this.stmt.columns().length > 0;
+ return this._getNativeStatement().columns().length > 0;
}
/**
@@ -329,7 +383,7 @@ class Statement {
async run(...bindParameters) {
try {
const { params, queryOptions } = splitBindParameters(bindParameters);
- return await this.stmt.run(params, queryOptions);
+ return await this._getNativeStatement().run(params, queryOptions);
} catch (err) {
throw convertError(err);
}
@@ -343,7 +397,7 @@ class Statement {
async get(...bindParameters) {
try {
const { params, queryOptions } = splitBindParameters(bindParameters);
- return await this.stmt.get(params, queryOptions);
+ return await this._getNativeStatement().get(params, queryOptions);
} catch (err) {
throw convertError(err);
}
@@ -357,7 +411,7 @@ class Statement {
async iterate(...bindParameters) {
try {
const { params, queryOptions } = splitBindParameters(bindParameters);
- const it = await this.stmt.iterate(params, queryOptions);
+ const it = await this._getNativeStatement().iterate(params, queryOptions);
return {
next() {
return it.next();
@@ -398,7 +452,7 @@ class Statement {
* Interrupts the statement.
*/
interrupt() {
- this.stmt.interrupt();
+ this._getNativeStatement().interrupt();
return this;
}
@@ -406,18 +460,23 @@ class Statement {
* Returns the columns in the result set returned by this prepared statement.
*/
columns() {
- return this.stmt.columns();
+ return this._getNativeStatement().columns();
}
/**
* Toggle 64-bit integer support.
*/
safeIntegers(toggle) {
- this.stmt.safeIntegers(toggle);
+ this._getNativeStatement().safeIntegers(toggle);
return this;
}
}
+if (symbolDispose != null) {
+ Database.prototype[symbolDispose] = Database.prototype.close;
+ Statement.prototype[symbolDispose] = Statement.prototype.close;
+}
+
module.exports = {
Authorization,
Database,
diff --git a/src/lib.rs b/src/lib.rs
index 61e1581..3af0d22 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -664,6 +664,14 @@ fn throw_database_closed_error(env: &Env) -> napi::Error {
err
}
+fn throw_not_open_error() -> napi::Error {
+ throw_sqlite_error(
+ "The database connection is not open".to_string(),
+ "SQLITE_NOTOPEN".to_string(),
+ 0,
+ )
+}
+
fn query_timeout_duration(timeout_ms: f64) -> Option {
if timeout_ms.is_finite() && timeout_ms > 0.0 {
Some(Duration::from_millis(timeout_ms as u64))
@@ -676,9 +684,9 @@ fn query_timeout_duration(timeout_ms: f64) -> Option {
#[napi]
pub struct Statement {
// The libSQL connection instance.
- conn: Arc,
+ conn: Option>,
// The libSQL statement instance.
- stmt: Arc,
+ stmt: Option>,
// The column names.
column_names: Vec,
// The access mode.
@@ -712,8 +720,8 @@ impl Statement {
.collect();
let stmt = Arc::new(stmt);
Self {
- conn,
- stmt,
+ conn: Some(conn),
+ stmt: Some(stmt),
column_names,
mode,
query_timeout,
@@ -733,13 +741,13 @@ impl Statement {
params: Option,
query_options: Option,
) -> Result {
- self.stmt.reset();
- let params = map_params(&self.stmt, params)?;
- let total_changes_before = self.conn.total_changes();
+ let stmt = self.statement()?;
+ let conn = self.connection()?;
+ stmt.reset();
+ let params = map_params(stmt.as_ref(), params)?;
+ let total_changes_before = conn.total_changes();
let start = std::time::Instant::now();
- let stmt = self.stmt.clone();
- let conn = self.conn.clone();
- let guard = self.start_timeout_guard(query_options);
+ let guard = self.start_timeout_guard(&conn, query_options);
let future = async move {
let _guard = guard;
@@ -774,13 +782,14 @@ impl Statement {
params: Option,
query_options: Option,
) -> Result {
+ let stmt = self.statement()?;
+ let conn = self.connection()?;
let safe_ints = self.mode.safe_ints.load(Ordering::SeqCst);
let raw = self.mode.raw.load(Ordering::SeqCst);
let pluck = self.mode.pluck.load(Ordering::SeqCst);
let timed = self.mode.timing.load(Ordering::SeqCst);
- let params = map_params(&self.stmt, params)?;
- let stmt = self.stmt.clone();
+ let params = map_params(stmt.as_ref(), params)?;
let column_names = self.column_names.clone();
let start = if timed {
@@ -790,7 +799,7 @@ impl Statement {
};
let stmt_fut = stmt.clone();
- let guard = self.start_timeout_guard(query_options);
+ let guard = self.start_timeout_guard(&conn, query_options);
let future = async move {
let _guard = guard;
let mut rows = stmt_fut.query(params).await.map_err(Error::from)?;
@@ -854,14 +863,14 @@ impl Statement {
params: Option,
query_options: Option,
) -> Result {
+ let stmt = self.statement()?;
+ let conn = self.connection()?;
let safe_ints = self.mode.safe_ints.load(Ordering::SeqCst);
let raw = self.mode.raw.load(Ordering::SeqCst);
let pluck = self.mode.pluck.load(Ordering::SeqCst);
- let stmt = self.stmt.clone();
stmt.reset();
- let params = map_params(&stmt, params).unwrap();
- let stmt = self.stmt.clone();
- let guard = self.start_timeout_guard(query_options);
+ let params = map_params(stmt.as_ref(), params)?;
+ let guard = self.start_timeout_guard(&conn, query_options);
let future = async move {
let rows = stmt.query(params).await.map_err(Error::from)?;
Ok::<_, napi::Error>(rows)
@@ -881,7 +890,8 @@ impl Statement {
#[napi]
pub fn raw(&self, raw: Option) -> Result<&Self> {
- let returns_data = !self.stmt.columns().is_empty();
+ let stmt = self.statement()?;
+ let returns_data = !stmt.columns().is_empty();
if !returns_data {
return Err(napi::Error::from_reason(
"The raw() method is only for statements that return data",
@@ -909,7 +919,8 @@ impl Statement {
#[napi]
pub fn columns(&self, env: Env) -> Result {
- let columns = self.stmt.columns();
+ let stmt = self.statement()?;
+ let columns = stmt.columns();
let mut js_array = env.create_array(columns.len() as u32)?;
for (i, col) in columns.iter().enumerate() {
let mut js_obj = env.create_object()?;
@@ -953,12 +964,29 @@ impl Statement {
#[napi]
pub fn interrupt(&self) -> Result<()> {
- self.stmt.interrupt().map_err(Error::from)?;
+ let stmt = self.statement()?;
+ stmt.interrupt().map_err(Error::from)?;
+ Ok(())
+ }
+
+ /// Closes the statement.
+ #[napi]
+ pub fn close(&mut self) -> Result<()> {
+ self.stmt = None;
+ self.conn = None;
Ok(())
}
}
impl Statement {
+ fn statement(&self) -> Result> {
+ self.stmt.clone().ok_or_else(throw_not_open_error)
+ }
+
+ fn connection(&self) -> Result> {
+ self.conn.clone().ok_or_else(throw_not_open_error)
+ }
+
fn resolve_query_timeout(&self, query_options: Option) -> Option {
match query_options.and_then(|o| o.queryTimeout) {
Some(timeout_ms) => query_timeout_duration(timeout_ms),
@@ -966,9 +994,13 @@ impl Statement {
}
}
- fn start_timeout_guard(&self, query_options: Option) -> Option {
+ fn start_timeout_guard(
+ &self,
+ conn: &Arc,
+ query_options: Option,
+ ) -> Option {
self.resolve_query_timeout(query_options)
- .map(|t| self.timeout_manager.register(&self.conn, t))
+ .map(|t| self.timeout_manager.register(conn, t))
}
}
@@ -980,6 +1012,8 @@ pub fn statement_get_sync(
params: Option,
query_options: Option,
) -> Result {
+ let inner_stmt = stmt.statement()?;
+ let conn = stmt.connection()?;
let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst);
let raw = stmt.mode.raw.load(Ordering::SeqCst);
let pluck = stmt.mode.pluck.load(Ordering::SeqCst);
@@ -992,10 +1026,10 @@ pub fn statement_get_sync(
};
let rt = runtime()?;
- let _guard = stmt.start_timeout_guard(query_options);
+ let _guard = stmt.start_timeout_guard(&conn, query_options);
rt.block_on(async move {
- let params = map_params(&stmt.stmt, params)?;
- let mut rows = stmt.stmt.query(params).await.map_err(Error::from)?;
+ let params = map_params(inner_stmt.as_ref(), params)?;
+ let mut rows = inner_stmt.query(params).await.map_err(Error::from)?;
let row = rows.next().await.map_err(Error::from)?;
let duration: Option = start.map(|start| start.elapsed().as_secs_f64());
let result = Statement::get_internal(
@@ -1007,7 +1041,7 @@ pub fn statement_get_sync(
pluck,
duration,
);
- stmt.stmt.reset();
+ inner_stmt.reset();
result
})
}
@@ -1019,21 +1053,23 @@ pub fn statement_run_sync(
params: Option,
query_options: Option,
) -> Result {
- stmt.stmt.reset();
+ let inner_stmt = stmt.statement()?;
+ let conn = stmt.connection()?;
+ inner_stmt.reset();
let rt = runtime()?;
- let _guard = stmt.start_timeout_guard(query_options);
+ let _guard = stmt.start_timeout_guard(&conn, query_options);
rt.block_on(async move {
- let params = map_params(&stmt.stmt, params)?;
- let total_changes_before = stmt.conn.total_changes();
+ let params = map_params(inner_stmt.as_ref(), params)?;
+ let total_changes_before = conn.total_changes();
let start = std::time::Instant::now();
- stmt.stmt.run(params).await.map_err(Error::from)?;
- let changes = if stmt.conn.total_changes() == total_changes_before {
+ inner_stmt.run(params).await.map_err(Error::from)?;
+ let changes = if conn.total_changes() == total_changes_before {
0
} else {
- stmt.conn.changes()
+ conn.changes()
};
- let last_insert_row_id = stmt.conn.last_insert_rowid();
+ let last_insert_row_id = conn.last_insert_rowid();
let duration = start.elapsed().as_secs_f64();
Ok(RunResult {
changes: changes as f64,
@@ -1051,14 +1087,15 @@ pub fn statement_iterate_sync(
query_options: Option,
) -> Result {
let rt = runtime()?;
+ let conn = stmt.connection()?;
let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst);
let raw = stmt.mode.raw.load(Ordering::SeqCst);
let pluck = stmt.mode.pluck.load(Ordering::SeqCst);
- let guard = stmt.start_timeout_guard(query_options);
- let inner_stmt = stmt.stmt.clone();
+ let guard = stmt.start_timeout_guard(&conn, query_options);
+ let inner_stmt = stmt.statement()?;
let (rows, column_names) = rt.block_on(async move {
inner_stmt.reset();
- let params = map_params(&inner_stmt, params)?;
+ let params = map_params(inner_stmt.as_ref(), params)?;
let rows = inner_stmt.query(params).await.map_err(Error::from)?;
let mut column_names = Vec::new();
for i in 0..rows.column_count() {