diff --git a/libsql/src/database.rs b/libsql/src/database.rs index 322913eefc..11fa5dcd28 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -719,7 +719,6 @@ impl Database { read_your_writes: *read_your_writes, context: db.sync_ctx.clone().unwrap(), state: std::sync::Arc::new(Mutex::new(State::Init)), - needs_pull: std::sync::atomic::AtomicBool::new(false).into(), }; let conn = std::sync::Arc::new(synced); diff --git a/libsql/src/sync/connection.rs b/libsql/src/sync/connection.rs index 1150af63e0..5c48e788d1 100644 --- a/libsql/src/sync/connection.rs +++ b/libsql/src/sync/connection.rs @@ -8,10 +8,7 @@ use crate::{ sync::SyncContext, BatchRows, Error, Result, Statement, Transaction, TransactionBehavior, }; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, -}; +use std::sync::Arc; use std::time::Duration; use tokio::sync::Mutex; @@ -24,7 +21,6 @@ pub struct SyncedConnection { pub read_your_writes: bool, pub context: Arc>, pub state: Arc>, - pub needs_pull: Arc, } impl SyncedConnection { @@ -110,35 +106,39 @@ impl Conn for SyncedConnection { async fn execute_batch(&self, sql: &str) -> Result { if self.should_execute_local(sql).await? { - if self.needs_pull.load(Ordering::Relaxed) { + self.local.execute_batch(sql) + } else { + let result = self.remote.execute_batch(sql).await; + if self.read_your_writes { let mut context = self.context.lock().await; crate::sync::try_pull(&mut context, &self.local).await?; - self.needs_pull.store(false, Ordering::Relaxed); } - self.local.execute_batch(sql) - } else { - self.remote.execute_batch(sql).await + result } } async fn execute_transactional_batch(&self, sql: &str) -> Result { if self.should_execute_local(sql).await? { - if self.needs_pull.load(Ordering::Relaxed) { - let mut context = self.context.lock().await; - crate::sync::try_pull(&mut context, &self.local).await?; - self.needs_pull.store(false, Ordering::Relaxed); - } self.local.execute_transactional_batch(sql)?; Ok(BatchRows::empty()) } else { - self.remote.execute_transactional_batch(sql).await + let result = self.remote.execute_transactional_batch(sql).await; + if self.read_your_writes { + let mut context = self.context.lock().await; + crate::sync::try_pull(&mut context, &self.local).await?; + } + result } } async fn prepare(&self, sql: &str) -> Result { if self.should_execute_local(sql).await? { - let stmt = Statement { + Ok(Statement { inner: Box::new(LibsqlStmt(self.local.prepare(sql)?)), + }) + } else { + let stmt = Statement { + inner: Box::new(self.remote.prepare(sql).await?), }; Ok(Statement { @@ -146,19 +146,9 @@ impl Conn for SyncedConnection { conn: self.local.clone(), inner: stmt, context: self.context.clone(), - needs_pull: self.needs_pull.clone(), + read_your_writes: self.read_your_writes, }), }) - } else { - let stmt = Statement { - inner: Box::new(self.remote.prepare(sql).await?), - }; - - if self.read_your_writes { - self.needs_pull.store(true, Ordering::Relaxed); - } - - Ok(stmt) } } diff --git a/libsql/src/sync/statement.rs b/libsql/src/sync/statement.rs index eebd78909c..b3de1338ef 100644 --- a/libsql/src/sync/statement.rs +++ b/libsql/src/sync/statement.rs @@ -4,14 +4,14 @@ use crate::{ statement::Stmt, sync::SyncContext, Column, Result, Rows, Statement, }; -use std::sync::{atomic::{AtomicBool, Ordering}, Arc}; +use std::sync::Arc; use tokio::sync::Mutex; pub struct SyncedStatement { pub conn: local::Connection, pub inner: Statement, pub context: Arc>, - pub needs_pull: Arc, + pub read_your_writes: bool, } #[async_trait::async_trait] @@ -21,30 +21,30 @@ impl Stmt for SyncedStatement { } async fn execute(&self, params: &Params) -> Result { - if self.needs_pull.load(Ordering::Relaxed) { + let result = self.inner.execute(params).await; + if self.read_your_writes { let mut context = self.context.lock().await; crate::sync::try_pull(&mut context, &self.conn).await?; - self.needs_pull.store(false, Ordering::Relaxed); } - self.inner.execute(params).await + result } async fn query(&self, params: &Params) -> Result { - if self.needs_pull.load(Ordering::Relaxed) { + let result = self.inner.query(params).await; + if self.read_your_writes { let mut context = self.context.lock().await; crate::sync::try_pull(&mut context, &self.conn).await?; - self.needs_pull.store(false, Ordering::Relaxed); } - self.inner.query(params).await + result } async fn run(&self, params: &Params) -> Result<()> { - if self.needs_pull.load(Ordering::Relaxed) { + let result = self.inner.run(params).await; + if self.read_your_writes { let mut context = self.context.lock().await; crate::sync::try_pull(&mut context, &self.conn).await?; - self.needs_pull.store(false, Ordering::Relaxed); } - self.inner.run(params).await + result } fn interrupt(&self) -> Result<()> {