diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 20281a7ab1..8e88a04140 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -23,6 +23,7 @@ jobs: runs-on: ubuntu-latest name: Run Checks env: + RUST_BACKTRACE: 1 RUSTFLAGS: -D warnings --cfg tokio_unstable steps: - uses: hecrj/setup-rust-action@v2 @@ -179,6 +180,7 @@ jobs: runs-on: ubuntu-latest name: Run Tests env: + RUST_BACKTRACE: 1 RUSTFLAGS: -D warnings --cfg tokio_unstable steps: - uses: hecrj/setup-rust-action@v2 diff --git a/libsql-server/tests/standalone/mod.rs b/libsql-server/tests/standalone/mod.rs index 046f147d4d..ad3fae958b 100644 --- a/libsql-server/tests/standalone/mod.rs +++ b/libsql-server/tests/standalone/mod.rs @@ -422,7 +422,7 @@ async fn insert_rows(conn: &Connection, start: u32, count: u32) -> libsql::Resul async fn insert_rows_with_args(conn: &Connection, start: u32, count: u32) -> libsql::Result<()> { for i in start..(start + count) { - let mut stmt = conn.prepare("INSERT INTO test(a, b) VALUES(?,?)").await?; + let stmt = conn.prepare("INSERT INTO test(a, b) VALUES(?,?)").await?; stmt.execute(params![i, i]).await?; } Ok(()) diff --git a/libsql/benches/benchmark.rs b/libsql/benches/benchmark.rs index bdc1f56f7d..69905c9d33 100644 --- a/libsql/benches/benchmark.rs +++ b/libsql/benches/benchmark.rs @@ -93,7 +93,7 @@ fn bench(c: &mut Criterion) { group.bench_function("in-memory-select-1-prepared", |b| { b.to_async(&rt).iter_batched( || block_on(conn.prepare("SELECT 1")).unwrap(), - |mut stmt| async move { + |stmt| async move { let mut rows = stmt.query(()).await.unwrap(); let row = rows.next().await.unwrap().unwrap(); assert_eq!(row.get::(0).unwrap(), 1); @@ -113,7 +113,7 @@ fn bench(c: &mut Criterion) { group.bench_function("in-memory-select-star-from-users-limit-1-unprepared", |b| { b.to_async(&rt).iter_batched( || block_on(conn.prepare("SELECT * FROM users LIMIT 1")).unwrap(), - |mut stmt| async move { + |stmt| async move { let mut rows = stmt.query(()).await.unwrap(); let row = rows.next().await.unwrap().unwrap(); assert_eq!(row.get::(0).unwrap(), 1); @@ -128,7 +128,7 @@ fn bench(c: &mut Criterion) { |b| { b.to_async(&rt).iter_batched( || block_on(conn.prepare("SELECT * FROM users LIMIT 100")).unwrap(), - |mut stmt| async move { + |stmt| async move { let mut rows = stmt.query(()).await.unwrap(); let row = rows.next().await.unwrap().unwrap(); assert_eq!(row.get::(0).unwrap(), 1); @@ -156,7 +156,7 @@ fn bench(c: &mut Criterion) { group.bench_function("local-replica-select-1-prepared", |b| { b.to_async(&rt).iter_batched( || block_on(conn.prepare("SELECT 1")).unwrap(), - |mut stmt| async move { + |stmt| async move { let mut rows = stmt.query(()).await.unwrap(); let row = rows.next().await.unwrap().unwrap(); assert_eq!(row.get::(0).unwrap(), 1); @@ -188,7 +188,7 @@ fn bench(c: &mut Criterion) { |b| { b.to_async(&rt).iter_batched( || block_on(conn.prepare("SELECT * FROM users LIMIT 1")).unwrap(), - |mut stmt| async move { + |stmt| async move { let mut rows = stmt.query(()).await.unwrap(); let row = rows.next().await.unwrap().unwrap(); assert_eq!(row.get::(0).unwrap(), 1); @@ -204,7 +204,7 @@ fn bench(c: &mut Criterion) { |b| { b.to_async(&rt).iter_batched( || block_on(conn.prepare("SELECT * FROM users LIMIT 100")).unwrap(), - |mut stmt| async move { + |stmt| async move { let mut rows = stmt.query(()).await.unwrap(); let row = rows.next().await.unwrap().unwrap(); assert_eq!(row.get::(0).unwrap(), 1); diff --git a/libsql/examples/deserialization.rs b/libsql/examples/deserialization.rs index abedec1f14..b1c8ce4780 100644 --- a/libsql/examples/deserialization.rs +++ b/libsql/examples/deserialization.rs @@ -22,7 +22,7 @@ async fn main() { .await .unwrap(); - let mut stmt = conn + let stmt = conn .prepare("INSERT INTO users (name, age, vision, avatar) VALUES (?1, ?2, ?3, ?4)") .await .unwrap(); @@ -30,7 +30,7 @@ async fn main() { .await .unwrap(); - let mut stmt = conn + let stmt = conn .prepare("SELECT * FROM users WHERE name = ?1") .await .unwrap(); diff --git a/libsql/examples/example.rs b/libsql/examples/example.rs index ab8e3f5c2f..1416d550d1 100644 --- a/libsql/examples/example.rs +++ b/libsql/examples/example.rs @@ -21,14 +21,14 @@ async fn main() { .await .unwrap(); - let mut stmt = conn + let stmt = conn .prepare("INSERT INTO users (email) VALUES (?1)") .await .unwrap(); stmt.execute(["foo@example.com"]).await.unwrap(); - let mut stmt = conn + let stmt = conn .prepare("SELECT * FROM users WHERE email = ?1") .await .unwrap(); diff --git a/libsql/examples/example_v2.rs b/libsql/examples/example_v2.rs index d434618aff..85815d3deb 100644 --- a/libsql/examples/example_v2.rs +++ b/libsql/examples/example_v2.rs @@ -21,14 +21,14 @@ async fn main() { .await .unwrap(); - let mut stmt = conn + let stmt = conn .prepare("INSERT INTO users (email) VALUES (?1)") .await .unwrap(); stmt.execute(["foo@example.com"]).await.unwrap(); - let mut stmt = conn + let stmt = conn .prepare("SELECT * FROM users WHERE email = ?1") .await .unwrap(); diff --git a/libsql/examples/flutter.rs b/libsql/examples/flutter.rs index 10eabf1f5c..ba928bcf30 100644 --- a/libsql/examples/flutter.rs +++ b/libsql/examples/flutter.rs @@ -30,14 +30,14 @@ async fn main() { .await .unwrap(); - let mut stmt = conn + let stmt = conn .prepare("INSERT INTO users (email) VALUES (?1)") .await .unwrap(); stmt.execute(["foo@example.com"]).await.unwrap(); - let mut stmt = conn + let stmt = conn .prepare("SELECT * FROM users WHERE email = ?1") .await .unwrap(); diff --git a/libsql/src/connection.rs b/libsql/src/connection.rs index a164321dc2..e9831411d6 100644 --- a/libsql/src/connection.rs +++ b/libsql/src/connection.rs @@ -170,7 +170,7 @@ impl Connection { /// For more info on how to pass params check [`IntoParams`]'s docs and on how to /// extract values out of the rows check the [`Rows`] docs. pub async fn query(&self, sql: &str, params: impl IntoParams) -> Result { - let mut stmt = self.prepare(sql).await?; + let stmt = self.prepare(sql).await?; stmt.query(params).await } diff --git a/libsql/src/hrana/hyper.rs b/libsql/src/hrana/hyper.rs index d32341796b..300602c27e 100644 --- a/libsql/src/hrana/hyper.rs +++ b/libsql/src/hrana/hyper.rs @@ -237,25 +237,25 @@ impl Conn for HttpConnection { impl crate::statement::Stmt for crate::hrana::Statement { fn finalize(&mut self) {} - async fn execute(&mut self, params: &Params) -> crate::Result { + async fn execute(&self, params: &Params) -> crate::Result { self.execute(params).await } - async fn query(&mut self, params: &Params) -> crate::Result { + async fn query(&self, params: &Params) -> crate::Result { self.query(params).await } - async fn run(&mut self, params: &Params) -> crate::Result<()> { + async fn run(&self, params: &Params) -> crate::Result<()> { self.run(params).await } - fn interrupt(&mut self) -> crate::Result<()> { + fn interrupt(&self) -> crate::Result<()> { Err(crate::Error::Misuse( "interrupt is not supported for remote connections".to_string(), )) } - fn reset(&mut self) {} + fn reset(&self) {} fn parameter_count(&self) -> usize { let stmt = &self.inner; diff --git a/libsql/src/hrana/mod.rs b/libsql/src/hrana/mod.rs index b29d5e4d34..129cf154b6 100644 --- a/libsql/src/hrana/mod.rs +++ b/libsql/src/hrana/mod.rs @@ -162,7 +162,7 @@ where } } - pub async fn execute(&mut self, params: &Params) -> crate::Result { + pub async fn execute(&self, params: &Params) -> crate::Result { let mut stmt = self.inner.clone(); bind_params(params.clone(), &mut stmt); @@ -170,7 +170,7 @@ where Ok(result.affected_row_count as usize) } - pub async fn run(&mut self, params: &Params) -> crate::Result<()> { + pub async fn run(&self, params: &Params) -> crate::Result<()> { let mut stmt = self.inner.clone(); bind_params(params.clone(), &mut stmt); @@ -179,7 +179,7 @@ where } pub(crate) async fn query_raw( - &mut self, + &self, params: &Params, ) -> crate::Result> { let mut stmt = self.inner.clone(); @@ -197,7 +197,7 @@ where T: HttpSend + Send + Sync + 'static, ::Stream: Send + Sync + 'static, { - pub async fn query(&mut self, params: &Params) -> crate::Result { + pub async fn query(&self, params: &Params) -> crate::Result { let rows = self.query_raw(params).await?; Ok(super::Rows::new(rows)) } diff --git a/libsql/src/hrana/stream.rs b/libsql/src/hrana/stream.rs index c63f600c75..686013d3c5 100644 --- a/libsql/src/hrana/stream.rs +++ b/libsql/src/hrana/stream.rs @@ -8,6 +8,7 @@ use libsql_hrana::proto::{ GetAutocommitStreamReq, PipelineReqBody, PipelineRespBody, SequenceStreamReq, StoreSqlStreamReq, StreamRequest, StreamResponse, StreamResult, }; +use std::cell::RefCell; use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering}; use std::sync::Arc; use tokio::sync::Mutex; @@ -66,8 +67,8 @@ where pipeline_url, cursor_url, auth_token, - sql_id_generator: 0, - baton: None, + sql_id_generator: RefCell::new(0), + baton: RefCell::new(None), }), }), } @@ -77,7 +78,7 @@ where /// Returns true if request was finalized correctly, false if stream was already closed. pub(super) async fn finalize(&mut self, req: StreamRequest) -> Result { let mut client = self.inner.stream.lock().await; - if client.baton.is_none() { + if client.baton.borrow().is_none() { tracing::trace!("baton not found - skipping finalize for {:?}", req); return Ok(false); } @@ -298,11 +299,11 @@ where T: HttpSend, { client: T, - baton: Option, + baton: RefCell>, pipeline_url: Arc, cursor_url: Arc, auth_token: Arc, - sql_id_generator: SqlId, + sql_id_generator: RefCell, } impl RawStream @@ -316,7 +317,7 @@ where pub async fn open_cursor(&mut self, batch: Batch) -> Result> { let msg = CursorReq { - baton: self.baton.clone(), + baton: self.baton.borrow().clone(), batch, }; let body = serde_json::to_string(&msg).map_err(HranaError::Json)?; @@ -336,7 +337,7 @@ where } // stream has been closed by the server Some(baton) => { tracing::trace!("client stream has been assigned with baton: `{}`", baton); - self.baton = Some(baton) + *self.baton.borrow_mut() = Some(baton) } } Ok(cursor) @@ -349,11 +350,11 @@ where tracing::trace!( "client stream sending {} requests with baton `{}`: {:?}", N, - self.baton.as_deref().unwrap_or_default(), + self.baton.borrow().as_deref().unwrap_or_default(), requests ); let msg = PipelineReqBody { - baton: self.baton.clone(), + baton: self.baton.borrow().clone(), requests: Vec::from(requests), }; let body = serde_json::to_string(&msg).map_err(HranaError::Json)?; @@ -375,7 +376,7 @@ where } // stream has been closed by the server Some(baton) => { tracing::trace!("client stream has been assigned with baton: `{}`", baton); - self.baton = Some(baton) + *self.baton.borrow_mut() = Some(baton) } } @@ -424,16 +425,17 @@ where Ok((resp, is_autocommit)) } - fn reset(&mut self) { - if let Some(baton) = self.baton.take() { + fn reset(&self) { + if let Some(baton) = self.baton.borrow_mut().take() { tracing::trace!("closing client stream (baton: `{}`)", baton); } - self.sql_id_generator = 0; + *self.sql_id_generator.borrow_mut() = 0; } fn next_sql_id(&mut self) -> SqlId { - self.sql_id_generator = self.sql_id_generator.wrapping_add(1); - self.sql_id_generator + let mut gen = self.sql_id_generator.borrow_mut(); + *gen = gen.wrapping_add(1); + *gen } } @@ -443,7 +445,8 @@ where T: HttpSend, { fn drop(&mut self) { - if let Some(baton) = self.baton.take() { + let baton = self.baton.get_mut().take(); + if let Some(baton) = baton { // only send a close request if stream was ever used to send the data tracing::trace!("closing client stream (baton: `{}`)", baton); let req = serde_json::to_string(&PipelineReqBody { diff --git a/libsql/src/local/impls.rs b/libsql/src/local/impls.rs index 65ade9dcfe..0445e8c94b 100644 --- a/libsql/src/local/impls.rs +++ b/libsql/src/local/impls.rs @@ -108,32 +108,32 @@ impl Stmt for LibsqlStmt { self.0.finalize(); } - async fn execute(&mut self, params: &Params) -> Result { + async fn execute(&self, params: &Params) -> Result { let params = params.clone(); let stmt = self.0.clone(); stmt.execute(¶ms).map(|i| i as usize) } - async fn query(&mut self, params: &Params) -> Result { + async fn query(&self, params: &Params) -> Result { let params = params.clone(); let stmt = self.0.clone(); stmt.query(¶ms).map(LibsqlRows).map(Rows::new) } - async fn run(&mut self, params: &Params) -> Result<()> { + async fn run(&self, params: &Params) -> Result<()> { let params = params.clone(); let stmt = self.0.clone(); stmt.run(¶ms) } - fn interrupt(&mut self) -> Result<()> { + fn interrupt(&self) -> Result<()> { self.0.interrupt() } - fn reset(&mut self) { + fn reset(&self) { self.0.reset(); } diff --git a/libsql/src/replication/connection.rs b/libsql/src/replication/connection.rs index c9f0eac242..50fd4dd939 100644 --- a/libsql/src/replication/connection.rs +++ b/libsql/src/replication/connection.rs @@ -654,8 +654,8 @@ async fn fetch_metas( impl Stmt for RemoteStatement { fn finalize(&mut self) {} - async fn execute(&mut self, params: &Params) -> Result { - if let Some(stmt) = &mut self.local_statement { + async fn execute(&self, params: &Params) -> Result { + if let Some(stmt) = &self.local_statement { return stmt.execute(params.clone()).await; } @@ -688,8 +688,8 @@ impl Stmt for RemoteStatement { Ok(affected_row_count as usize) } - async fn query(&mut self, params: &Params) -> Result { - if let Some(stmt) = &mut self.local_statement { + async fn query(&self, params: &Params) -> Result { + if let Some(stmt) = &self.local_statement { return stmt.query(params.clone()).await; } @@ -722,8 +722,8 @@ impl Stmt for RemoteStatement { Ok(Rows::new(RemoteRows(rows, 0))) } - async fn run(&mut self, params: &Params) -> Result<()> { - if let Some(stmt) = &mut self.local_statement { + async fn run(&self, params: &Params) -> Result<()> { + if let Some(stmt) = &self.local_statement { return stmt.run(params.clone()).await; } @@ -749,13 +749,13 @@ impl Stmt for RemoteStatement { Ok(()) } - fn interrupt(&mut self) -> Result<()> { + fn interrupt(&self) -> Result<()> { Err(Error::Misuse( "interrupt is not supported for remote connections".to_string(), )) } - fn reset(&mut self) {} + fn reset(&self) {} fn parameter_count(&self) -> usize { if let Some(stmt) = self.local_statement.as_ref() { diff --git a/libsql/src/statement.rs b/libsql/src/statement.rs index fa451350e9..861fdf8023 100644 --- a/libsql/src/statement.rs +++ b/libsql/src/statement.rs @@ -8,15 +8,15 @@ use crate::{Row, Rows}; pub(crate) trait Stmt { fn finalize(&mut self); - async fn execute(&mut self, params: &Params) -> Result; + async fn execute(&self, params: &Params) -> Result; - async fn query(&mut self, params: &Params) -> Result; + async fn query(&self, params: &Params) -> Result; - async fn run(&mut self, params: &Params) -> Result<()>; + async fn run(&self, params: &Params) -> Result<()>; - fn interrupt(&mut self) -> Result<()>; + fn interrupt(&self) -> Result<()>; - fn reset(&mut self); + fn reset(&self); fn parameter_count(&self) -> usize; @@ -39,13 +39,13 @@ impl Statement { } /// Execute queries on the statement, check [`Connection::execute`] for usage. - pub async fn execute(&mut self, params: impl IntoParams) -> Result { + pub async fn execute(&self, params: impl IntoParams) -> Result { tracing::trace!("execute for prepared statement"); self.inner.execute(¶ms.into_params()?).await } /// Execute a query on the statement, check [`Connection::query`] for usage. - pub async fn query(&mut self, params: impl IntoParams) -> Result { + pub async fn query(&self, params: impl IntoParams) -> Result { tracing::trace!("query for prepared statement"); self.inner.query(¶ms.into_params()?).await } @@ -58,14 +58,14 @@ impl Statement { /// provided to execute any type of SQL statement. /// /// Note: This is an extension to the Rusqlite API. - pub async fn run(&mut self, params: impl IntoParams) -> Result<()> { + pub async fn run(&self, params: impl IntoParams) -> Result<()> { tracing::trace!("run for prepared statement"); self.inner.run(¶ms.into_params()?).await?; Ok(()) } /// Interrupt the statement. - pub fn interrupt(&mut self) -> Result<()> { + pub fn interrupt(&self) -> Result<()> { self.inner.interrupt() } @@ -83,7 +83,7 @@ impl Statement { } /// Reset the state of this prepared statement. - pub fn reset(&mut self) { + pub fn reset(&self) { self.inner.reset(); } diff --git a/libsql/src/sync/connection.rs b/libsql/src/sync/connection.rs index c2809c0f57..1150af63e0 100644 --- a/libsql/src/sync/connection.rs +++ b/libsql/src/sync/connection.rs @@ -104,7 +104,7 @@ impl SyncedConnection { #[async_trait::async_trait] impl Conn for SyncedConnection { async fn execute(&self, sql: &str, params: Params) -> Result { - let mut stmt = self.prepare(sql).await?; + let stmt = self.prepare(sql).await?; stmt.execute(params).await.map(|v| v as u64) } diff --git a/libsql/src/sync/statement.rs b/libsql/src/sync/statement.rs index ad2183b0f4..eebd78909c 100644 --- a/libsql/src/sync/statement.rs +++ b/libsql/src/sync/statement.rs @@ -20,7 +20,7 @@ impl Stmt for SyncedStatement { self.inner.finalize() } - async fn execute(&mut self, params: &Params) -> Result { + async fn execute(&self, params: &Params) -> Result { if self.needs_pull.load(Ordering::Relaxed) { let mut context = self.context.lock().await; crate::sync::try_pull(&mut context, &self.conn).await?; @@ -29,7 +29,7 @@ impl Stmt for SyncedStatement { self.inner.execute(params).await } - async fn query(&mut self, params: &Params) -> Result { + async fn query(&self, params: &Params) -> Result { if self.needs_pull.load(Ordering::Relaxed) { let mut context = self.context.lock().await; crate::sync::try_pull(&mut context, &self.conn).await?; @@ -38,7 +38,7 @@ impl Stmt for SyncedStatement { self.inner.query(params).await } - async fn run(&mut self, params: &Params) -> Result<()> { + async fn run(&self, params: &Params) -> Result<()> { if self.needs_pull.load(Ordering::Relaxed) { let mut context = self.context.lock().await; crate::sync::try_pull(&mut context, &self.conn).await?; @@ -47,11 +47,11 @@ impl Stmt for SyncedStatement { self.inner.run(params).await } - fn interrupt(&mut self) -> Result<()> { + fn interrupt(&self) -> Result<()> { self.inner.interrupt() } - fn reset(&mut self) { + fn reset(&self) { self.inner.reset() } diff --git a/libsql/src/wasm/mod.rs b/libsql/src/wasm/mod.rs index 6bcd133303..18a0a44e05 100644 --- a/libsql/src/wasm/mod.rs +++ b/libsql/src/wasm/mod.rs @@ -70,7 +70,7 @@ where { pub async fn execute(&self, sql: &str, params: impl IntoParams) -> crate::Result { tracing::trace!("executing `{}`", sql); - let mut stmt = crate::hrana::Statement::new( + let stmt = crate::hrana::Statement::new( self.conn.current_stream().clone(), sql.to_string(), true, @@ -100,7 +100,7 @@ where pub async fn query(&self, sql: &str, params: impl IntoParams) -> crate::Result { tracing::trace!("querying `{}`", sql); - let mut stmt = crate::hrana::Statement::new( + let stmt = crate::hrana::Statement::new( self.conn.current_stream().clone(), sql.to_string(), true, @@ -139,7 +139,7 @@ where pub async fn query(&self, sql: &str, params: impl IntoParams) -> crate::Result { tracing::trace!("querying `{}`", sql); let stream = self.inner.stream().clone(); - let mut stmt = crate::hrana::Statement::new(stream, sql.to_string(), true).await?; + let stmt = crate::hrana::Statement::new(stream, sql.to_string(), true).await?; let rows = stmt.query_raw(¶ms.into_params()?).await?; Ok(Rows { inner: Box::new(rows), @@ -149,7 +149,7 @@ where pub async fn execute(&self, sql: &str, params: impl IntoParams) -> crate::Result { tracing::trace!("executing `{}`", sql); let stream = self.inner.stream().clone(); - let mut stmt = crate::hrana::Statement::new(stream, sql.to_string(), true).await?; + let stmt = crate::hrana::Statement::new(stream, sql.to_string(), true).await?; let rows = stmt.execute(¶ms.into_params()?).await?; Ok(rows as u64) } diff --git a/libsql/tests/integration_tests.rs b/libsql/tests/integration_tests.rs index 2101a2e1ea..57addab948 100644 --- a/libsql/tests/integration_tests.rs +++ b/libsql/tests/integration_tests.rs @@ -600,7 +600,7 @@ async fn debug_print_row() { .await .unwrap(); - let mut stmt = conn.prepare("SELECT * FROM users").await.unwrap(); + let stmt = conn.prepare("SELECT * FROM users").await.unwrap(); let mut rows = stmt.query(()).await.unwrap(); assert_eq!( format!("{:?}", rows.next().await.unwrap().unwrap()),