diff --git a/Cargo.lock b/Cargo.lock index 735ca20..1a3c398 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -347,6 +347,7 @@ checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" dependencies = [ "iana-time-zone", "num-traits", + "serde", "windows-link", ] @@ -707,6 +708,7 @@ dependencies = [ "tracing", "tracing-subscriber", "url", + "uuid", "which", ] @@ -2308,10 +2310,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef4605b7c057056dd35baeb6ac0c0338e4975b1f2bef0f65da953285eb007095" dependencies = [ "bytes", + "chrono", "fallible-iterator 0.2.0", "postgres-protocol", "serde_core", "serde_json", + "uuid", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index eb9bbe1..40fa659 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,8 @@ categories = ["command-line-utilities", "database"] [dependencies] tokio = { version = "1.35", features = ["full"] } -tokio-postgres = { version = "0.7", features = ["with-serde_json-1"] } +tokio-postgres = { version = "0.7", features = ["with-serde_json-1", "with-chrono-0_4", "with-uuid-1"] } +uuid = { version = "1", features = ["v4"] } clap = { version = "4.4", features = ["derive", "env"] } anyhow = "1.0" tracing = "0.1" @@ -40,7 +41,7 @@ bson = "2.9" mysql_async = "0.34" dirs = "5.0" url = "2.5" -chrono = { version = "0.4", default-features = false, features = ["clock"] } +chrono = { version = "0.4", default-features = false, features = ["clock", "serde"] } [[test]] name = "fallback_test" diff --git a/Dockerfile b/Dockerfile index b599ea6..bd302dc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ # syntax=docker/dockerfile:1 -FROM debian:bookworm-slim AS downloader +FROM ubuntu:24.04 AS downloader ARG VERSION=latest ENV BINARY_NAME=database-replicator-linux-x64-binary ENV RELEASE_ROOT=https://github.com/serenorg/database-replicator/releases @@ -16,13 +16,13 @@ RUN set -eux; \ curl -fL "$URL" -o /tmp/database-replicator && \ chmod +x /tmp/database-replicator -FROM debian:bookworm-slim +FROM ubuntu:24.04 LABEL org.opencontainers.image.title="database-replicator" \ org.opencontainers.image.description="Seren database replicator CLI" \ org.opencontainers.image.source="https://github.com/serenorg/database-replicator" RUN apt-get update && \ - apt-get install -y --no-install-recommends ca-certificates libssl3 libpq5 postgresql-client && \ + apt-get install -y --no-install-recommends ca-certificates libsqlite3-0 libssl3 libpq5 postgresql-client && \ rm -rf /var/lib/apt/lists/* && \ useradd -m replicator diff --git a/src/lib.rs b/src/lib.rs index c311d38..6263910 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,7 @@ pub mod sqlite; pub mod state; pub mod table_rules; pub mod utils; +pub mod xmin; use anyhow::{bail, Result}; diff --git a/src/main.rs b/src/main.rs index 2af3061..945eefe 100644 --- a/src/main.rs +++ b/src/main.rs @@ -188,6 +188,38 @@ enum Commands { #[command(flatten)] args: commands::target::TargetArgs, }, + /// Run xmin-based incremental sync (alternative to logical replication) + #[command(name = "xmin-sync")] + XminSync { + #[arg(long)] + source: String, + #[arg(long)] + target: Option, + /// Schema to sync (defaults to "public") + #[arg(long, default_value = "public")] + schema: String, + /// Tables to sync (comma-separated, syncs all if not specified) + #[arg(long, value_delimiter = ',')] + tables: Option>, + /// Sync interval in seconds (default: 60) + #[arg(long, default_value_t = 60)] + interval: u64, + /// Reconciliation interval in seconds for delete detection (default: 3600 = 1 hour) + #[arg(long, default_value_t = 3600)] + reconcile_interval: u64, + /// Batch size for reading changes (default: 1000) + #[arg(long, default_value_t = 1000)] + batch_size: usize, + /// Path to state file for tracking sync progress + #[arg(long)] + state_file: Option, + /// Run a single sync cycle then exit (useful for cron jobs) + #[arg(long)] + once: bool, + /// Skip reconciliation (delete detection) + #[arg(long)] + no_reconcile: bool, + }, } #[tokio::main] @@ -583,6 +615,37 @@ async fn main() -> anyhow::Result<()> { commands::verify(&source, &target, Some(filter)).await } Commands::Target { args } => commands::target(args).await, + Commands::XminSync { + source, + target, + schema, + tables, + interval, + reconcile_interval, + batch_size, + state_file, + once, + no_reconcile, + } => { + let state = database_replicator::state::load()?; + let target = target.or(state.target_url).ok_or_else(|| { + anyhow::anyhow!("Target database URL not provided and not set in state. Use `--target` or `database-replicator target set`.") + })?; + + xmin_sync( + source, + target, + schema, + tables, + interval, + reconcile_interval, + batch_size, + state_file, + once, + no_reconcile, + ) + .await + } } } @@ -1000,3 +1063,131 @@ enum SerenTargetMode { Project, Url, } + +/// Run xmin-based incremental sync between source and target databases +#[allow(clippy::too_many_arguments)] +async fn xmin_sync( + source: String, + target: String, + schema: String, + tables: Option>, + interval: u64, + reconcile_interval: u64, + batch_size: usize, + state_file: Option, + once: bool, + no_reconcile: bool, +) -> anyhow::Result<()> { + use database_replicator::xmin::{DaemonConfig, SyncDaemon, SyncState}; + use std::path::PathBuf; + use std::time::Duration; + + tracing::info!("Starting xmin-based sync..."); + tracing::info!( + "Source: {}", + database_replicator::utils::strip_password_from_url(&source).unwrap_or_else(|_| source.clone()) + ); + tracing::info!( + "Target: {}", + database_replicator::utils::strip_password_from_url(&target).unwrap_or_else(|_| target.clone()) + ); + tracing::info!("Schema: {}", schema); + if let Some(ref t) = tables { + tracing::info!("Tables: {}", t.join(", ")); + } else { + tracing::info!("Tables: all"); + } + + // CRITICAL: Ensure source and target are different to prevent data loss + database_replicator::utils::validate_source_target_different(&source, &target) + .context("Source and target validation failed")?; + tracing::info!("Verified source and target are different databases"); + + // Build daemon config + let state_path = state_file + .map(PathBuf::from) + .unwrap_or_else(SyncState::default_path); + + let reconcile_interval_duration = if no_reconcile { + None + } else { + Some(Duration::from_secs(reconcile_interval)) + }; + + let config = DaemonConfig { + sync_interval: Duration::from_secs(interval), + reconcile_interval: reconcile_interval_duration, + state_path, + batch_size, + tables: tables.unwrap_or_default(), + schema, + }; + + tracing::info!("Sync interval: {}s", interval); + if let Some(ref ri) = config.reconcile_interval { + tracing::info!("Reconcile interval: {}s", ri.as_secs()); + } else { + tracing::info!("Reconciliation disabled"); + } + tracing::info!("Batch size: {}", batch_size); + tracing::info!("State file: {:?}", config.state_path); + + // Create the daemon + let daemon = SyncDaemon::new(source.clone(), target.clone(), config); + + if once { + // Run a single sync cycle + tracing::info!("Running single sync cycle..."); + + let stats = daemon.run_sync_cycle().await?; + + tracing::info!("Sync cycle complete:"); + tracing::info!(" Tables synced: {}", stats.tables_synced); + tracing::info!(" Rows synced: {}", stats.rows_synced); + if !stats.errors.is_empty() { + tracing::warn!(" Errors: {}", stats.errors.len()); + for err in &stats.errors { + tracing::warn!(" - {}", err); + } + } + + println!(); + println!("========================================"); + println!("Xmin sync cycle complete"); + println!("========================================"); + println!(" Tables synced: {}", stats.tables_synced); + println!(" Rows synced: {}", stats.rows_synced); + if !stats.errors.is_empty() { + println!(" Errors: {}", stats.errors.len()); + } + } else { + // Run continuous sync + tracing::info!("Starting continuous sync daemon..."); + tracing::info!("Press Ctrl+C to stop"); + + println!(); + println!("========================================"); + println!("Starting xmin-based continuous sync"); + println!("========================================"); + println!(" Sync interval: {}s", interval); + println!(" Press Ctrl+C to stop"); + println!(); + + // Create shutdown channel + let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1); + + // Set up Ctrl+C handler + let shutdown_tx_clone = shutdown_tx.clone(); + tokio::spawn(async move { + tokio::signal::ctrl_c() + .await + .expect("Failed to listen for Ctrl+C"); + tracing::info!("Received shutdown signal"); + let _ = shutdown_tx_clone.send(()); + }); + + daemon.run(shutdown_rx).await?; + } + + Ok(()) +} diff --git a/src/xmin/daemon.rs b/src/xmin/daemon.rs new file mode 100644 index 0000000..fb010cb --- /dev/null +++ b/src/xmin/daemon.rs @@ -0,0 +1,398 @@ +// ABOUTME: SyncDaemon for xmin-based sync - orchestrates continuous replication +// ABOUTME: Runs sync cycles at configurable intervals with reconciliation + +use anyhow::{Context, Result}; +use std::path::PathBuf; +use std::time::Duration; +use tokio::time::interval; + +use super::reader::XminReader; +use super::reconciler::Reconciler; +use super::state::SyncState; +use super::writer::{get_primary_key_columns, get_table_columns, row_to_values, ChangeWriter}; + +/// Configuration for the SyncDaemon. +#[derive(Debug, Clone)] +pub struct DaemonConfig { + /// Interval between sync cycles + pub sync_interval: Duration, + /// Interval between reconciliation cycles (delete detection) + /// Set to None to disable reconciliation + pub reconcile_interval: Option, + /// Path to store sync state + pub state_path: PathBuf, + /// Maximum rows to process per batch + pub batch_size: usize, + /// Tables to sync (empty = all tables) + pub tables: Vec, + /// Schema to sync from + pub schema: String, +} + +impl Default for DaemonConfig { + fn default() -> Self { + Self { + sync_interval: Duration::from_secs(60), + reconcile_interval: Some(Duration::from_secs(3600)), // 1 hour + state_path: SyncState::default_path(), + batch_size: 1000, + tables: Vec::new(), + schema: "public".to_string(), + } + } +} + +/// Statistics from a sync cycle. +#[derive(Debug, Clone, Default)] +pub struct SyncStats { + pub tables_synced: usize, + pub rows_synced: u64, + pub rows_deleted: u64, + pub errors: Vec, + pub duration_ms: u64, +} + +impl SyncStats { + /// Check if the sync cycle completed without errors. + pub fn is_success(&self) -> bool { + self.errors.is_empty() + } +} + +/// SyncDaemon orchestrates continuous xmin-based replication. +/// +/// It runs periodic sync cycles that: +/// 1. Read changed rows from source using xmin +/// 2. Apply changes to target using upsert +/// 3. Periodically run reconciliation to detect deletes +/// 4. Persist sync state for resume capability +pub struct SyncDaemon { + config: DaemonConfig, + source_url: String, + target_url: String, +} + +impl SyncDaemon { + /// Create a new SyncDaemon with the given configuration. + pub fn new(source_url: String, target_url: String, config: DaemonConfig) -> Self { + Self { + config, + source_url, + target_url, + } + } + + /// Run a single sync cycle for all configured tables. + /// + /// This is the main entry point for synchronization. It: + /// 1. Loads or creates sync state + /// 2. Connects to source and target databases + /// 3. Syncs each table + /// 4. Saves updated state + pub async fn run_sync_cycle(&self) -> Result { + let start = std::time::Instant::now(); + let mut stats = SyncStats::default(); + + // Load or create sync state + let mut state = self.load_or_create_state().await?; + + // Connect to databases + let source_client = crate::postgres::connect_with_retry(&self.source_url) + .await + .context("Failed to connect to source database")?; + let target_client = crate::postgres::connect_with_retry(&self.target_url) + .await + .context("Failed to connect to target database")?; + + let reader = XminReader::new(&source_client); + let writer = ChangeWriter::new(&target_client); + + // Get tables to sync + let tables = if self.config.tables.is_empty() { + reader.list_tables(&self.config.schema).await? + } else { + self.config.tables.clone() + }; + + // Sync each table + for table in &tables { + match self + .sync_table(&reader, &writer, &mut state, &self.config.schema, table) + .await + { + Ok(rows) => { + stats.tables_synced += 1; + stats.rows_synced += rows; + } + Err(e) => { + let error_msg = format!("Failed to sync {}.{}: {}", self.config.schema, table, e); + tracing::error!("{}", error_msg); + stats.errors.push(error_msg); + } + } + } + + // Save state + state.save(&self.config.state_path).await?; + + stats.duration_ms = start.elapsed().as_millis() as u64; + Ok(stats) + } + + /// Run reconciliation to detect and delete orphaned rows. + pub async fn run_reconciliation(&self) -> Result { + let start = std::time::Instant::now(); + let mut stats = SyncStats::default(); + + // Connect to databases + let source_client = crate::postgres::connect_with_retry(&self.source_url) + .await + .context("Failed to connect to source database")?; + let target_client = crate::postgres::connect_with_retry(&self.target_url) + .await + .context("Failed to connect to target database")?; + + let reconciler = Reconciler::new(&source_client, &target_client); + let reader = XminReader::new(&source_client); + + // Get tables to reconcile + let tables = if self.config.tables.is_empty() { + reader.list_tables(&self.config.schema).await? + } else { + self.config.tables.clone() + }; + + // Reconcile each table + for table in &tables { + // Get primary key columns + let pk_columns = reader.get_primary_key(&self.config.schema, table).await?; + if pk_columns.is_empty() { + tracing::warn!( + "Skipping reconciliation for {}.{}: no primary key", + self.config.schema, + table + ); + continue; + } + + match reconciler + .reconcile_table(&self.config.schema, table, &pk_columns) + .await + { + Ok(deleted) => { + stats.tables_synced += 1; + stats.rows_deleted += deleted; + } + Err(e) => { + let error_msg = format!( + "Failed to reconcile {}.{}: {}", + self.config.schema, table, e + ); + tracing::error!("{}", error_msg); + stats.errors.push(error_msg); + } + } + } + + stats.duration_ms = start.elapsed().as_millis() as u64; + Ok(stats) + } + + /// Run the daemon continuously until stopped. + /// + /// This starts the main loop that runs sync cycles at the configured interval. + /// Reconciliation runs at its own interval if configured. + pub async fn run(&self, mut shutdown: tokio::sync::broadcast::Receiver<()>) -> Result<()> { + let mut sync_interval = interval(self.config.sync_interval); + let mut reconcile_interval = self + .config + .reconcile_interval + .map(|d| interval(d)); + + let mut cycles = 0u64; + let mut reconcile_cycles = 0u64; + + tracing::info!( + "Starting SyncDaemon with sync_interval={:?}, reconcile_interval={:?}", + self.config.sync_interval, + self.config.reconcile_interval + ); + + loop { + tokio::select! { + _ = sync_interval.tick() => { + cycles += 1; + tracing::info!("Starting sync cycle {}", cycles); + + match self.run_sync_cycle().await { + Ok(stats) => { + tracing::info!( + "Sync cycle {} completed: {} tables, {} rows in {}ms", + cycles, + stats.tables_synced, + stats.rows_synced, + stats.duration_ms + ); + if !stats.errors.is_empty() { + tracing::warn!("Sync cycle had {} errors", stats.errors.len()); + } + } + Err(e) => { + tracing::error!("Sync cycle {} failed: {}", cycles, e); + } + } + } + _ = async { + if let Some(ref mut interval) = reconcile_interval { + interval.tick().await + } else { + std::future::pending::().await + } + } => { + reconcile_cycles += 1; + tracing::info!("Starting reconciliation cycle {}", reconcile_cycles); + + match self.run_reconciliation().await { + Ok(stats) => { + tracing::info!( + "Reconciliation cycle {} completed: {} tables, {} rows deleted in {}ms", + reconcile_cycles, + stats.tables_synced, + stats.rows_deleted, + stats.duration_ms + ); + } + Err(e) => { + tracing::error!("Reconciliation cycle {} failed: {}", reconcile_cycles, e); + } + } + } + _ = shutdown.recv() => { + tracing::info!("Shutdown signal received, stopping SyncDaemon"); + break; + } + } + } + + Ok(()) + } + + /// Sync a single table. + async fn sync_table( + &self, + reader: &XminReader<'_>, + writer: &ChangeWriter<'_>, + state: &mut SyncState, + schema: &str, + table: &str, + ) -> Result { + // Get table state + let table_state = state.get_or_create_table(schema, table); + let since_xmin = table_state.last_xmin; + + // Get table metadata + let columns = get_table_columns(writer.client(), schema, table).await?; + let pk_columns = get_primary_key_columns(writer.client(), schema, table).await?; + + if pk_columns.is_empty() { + anyhow::bail!("Table {}.{} has no primary key", schema, table); + } + + let column_names: Vec = columns.iter().map(|(name, _)| name.clone()).collect(); + + // Read changes + let (rows, max_xmin) = reader + .read_changes(schema, table, &column_names, since_xmin) + .await?; + + if rows.is_empty() { + tracing::debug!("No changes in {}.{} since xmin {}", schema, table, since_xmin); + return Ok(0); + } + + tracing::info!( + "Found {} changed rows in {}.{} (xmin {} -> {})", + rows.len(), + schema, + table, + since_xmin, + max_xmin + ); + + // Convert rows to values (excluding the _xmin column we added) + let values: Vec>> = rows + .iter() + .map(|row| row_to_values(row, &columns)) + .collect(); + + // Apply changes + let affected = writer + .apply_batch(schema, table, &pk_columns, &column_names, values) + .await?; + + // Update state + state.update_table(schema, table, max_xmin, affected); + + Ok(affected) + } + + /// Load existing state or create new state. + async fn load_or_create_state(&self) -> Result { + if self.config.state_path.exists() { + match SyncState::load(&self.config.state_path).await { + Ok(state) => { + tracing::info!("Loaded existing sync state from {:?}", self.config.state_path); + return Ok(state); + } + Err(e) => { + tracing::warn!( + "Failed to load sync state from {:?}: {}. Creating new state.", + self.config.state_path, + e + ); + } + } + } + + tracing::info!("Creating new sync state"); + Ok(SyncState::new(&self.source_url, &self.target_url)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_daemon_config_default() { + let config = DaemonConfig::default(); + assert_eq!(config.sync_interval, Duration::from_secs(60)); + assert_eq!(config.reconcile_interval, Some(Duration::from_secs(3600))); + assert_eq!(config.batch_size, 1000); + assert_eq!(config.schema, "public"); + } + + #[test] + fn test_sync_stats_success() { + let stats = SyncStats { + tables_synced: 5, + rows_synced: 100, + rows_deleted: 0, + errors: vec![], + duration_ms: 500, + }; + assert!(stats.is_success()); + } + + #[test] + fn test_sync_stats_with_errors() { + let stats = SyncStats { + tables_synced: 4, + rows_synced: 80, + rows_deleted: 0, + errors: vec!["Failed to sync table X".to_string()], + duration_ms: 500, + }; + assert!(!stats.is_success()); + } +} diff --git a/src/xmin/mod.rs b/src/xmin/mod.rs new file mode 100644 index 0000000..37cd2fa --- /dev/null +++ b/src/xmin/mod.rs @@ -0,0 +1,14 @@ +// ABOUTME: xmin-based sync module for incremental PostgreSQL replication +// ABOUTME: Provides change detection using PostgreSQL's xmin system column + +pub mod daemon; +pub mod reader; +pub mod reconciler; +pub mod state; +pub mod writer; + +pub use daemon::{DaemonConfig, SyncDaemon, SyncStats}; +pub use reader::{BatchReader, ColumnInfo, XminReader}; +pub use reconciler::{ReconcileConfig, ReconcileResult, Reconciler}; +pub use state::{SyncState, TableSyncState}; +pub use writer::{get_primary_key_columns, get_table_columns, row_to_values, ChangeWriter}; diff --git a/src/xmin/reader.rs b/src/xmin/reader.rs new file mode 100644 index 0000000..750d239 --- /dev/null +++ b/src/xmin/reader.rs @@ -0,0 +1,324 @@ +// ABOUTME: XminReader for xmin-based sync - reads changed rows from source PostgreSQL +// ABOUTME: Uses xmin system column to detect rows modified since last sync + +use anyhow::{Context, Result}; +use tokio_postgres::{Client, Row}; + +/// Reads changed rows from a PostgreSQL table using xmin-based change detection. +/// +/// PostgreSQL's `xmin` system column contains the transaction ID that last modified +/// each row. By tracking the maximum xmin seen, we can query for only rows that +/// have been modified since the last sync. +/// +/// Note: xmin wraps around at 2^32 transactions, so this method is suitable for +/// incremental syncs but not for long-term archival purposes. +pub struct XminReader<'a> { + client: &'a Client, +} + +impl<'a> XminReader<'a> { + /// Create a new XminReader for the given PostgreSQL client connection. + pub fn new(client: &'a Client) -> Self { + Self { client } + } + + /// Get the current transaction ID (xmin snapshot) from the database. + /// + /// This should be called at the start of a sync to establish the high-water mark. + pub async fn get_current_xmin(&self) -> Result { + let row = self + .client + .query_one("SELECT txid_current()::text::bigint", &[]) + .await + .context("Failed to get current transaction ID")?; + + let txid: i64 = row.get(0); + // xmin is stored as u32, txid_current() returns i64 + // We mask to get the 32-bit xmin value + Ok((txid & 0xFFFFFFFF) as u32) + } + + /// Read all rows from a table that have xmin greater than the given value. + /// + /// # Arguments + /// + /// * `schema` - The schema name (e.g., "public") + /// * `table` - The table name + /// * `columns` - Column names to select (pass empty slice to select all) + /// * `since_xmin` - Only return rows with xmin > this value (0 = all rows) + /// + /// # Returns + /// + /// A tuple of (rows, max_xmin) where max_xmin is the highest xmin seen in the result set. + pub async fn read_changes( + &self, + schema: &str, + table: &str, + columns: &[String], + since_xmin: u32, + ) -> Result<(Vec, u32)> { + let column_list = if columns.is_empty() { + "*".to_string() + } else { + columns + .iter() + .map(|c| format!("\"{}\"", c)) + .collect::>() + .join(", ") + }; + + // Query rows where xmin > since_xmin, including the xmin value + let query = format!( + "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" WHERE xmin::text::bigint > $1 ORDER BY xmin", + column_list, schema, table + ); + + let rows = self + .client + .query(&query, &[&(since_xmin as i64)]) + .await + .with_context(|| format!("Failed to read changes from {}.{}", schema, table))?; + + // Find the max xmin in the result set + let max_xmin = rows + .iter() + .map(|row| { + let xmin: i64 = row.get("_xmin"); + (xmin & 0xFFFFFFFF) as u32 + }) + .max() + .unwrap_or(since_xmin); + + Ok((rows, max_xmin)) + } + + /// Read changes in batches to handle large tables efficiently. + /// + /// # Arguments + /// + /// * `schema` - The schema name + /// * `table` - The table name + /// * `columns` - Column names to select + /// * `since_xmin` - Only return rows with xmin > this value + /// * `batch_size` - Maximum rows per batch + /// + /// # Returns + /// + /// An iterator-like struct that yields batches of rows. + pub async fn read_changes_batched( + &self, + schema: &str, + table: &str, + columns: &[String], + since_xmin: u32, + batch_size: usize, + ) -> Result { + Ok(BatchReader { + schema: schema.to_string(), + table: table.to_string(), + columns: columns.to_vec(), + current_xmin: since_xmin, + batch_size, + exhausted: false, + }) + } + + /// Execute a batched read query and return the next batch. + pub async fn fetch_batch(&self, batch_reader: &mut BatchReader) -> Result, u32)>> { + if batch_reader.exhausted { + return Ok(None); + } + + let column_list = if batch_reader.columns.is_empty() { + "*".to_string() + } else { + batch_reader + .columns + .iter() + .map(|c| format!("\"{}\"", c)) + .collect::>() + .join(", ") + }; + + let query = format!( + "SELECT {}, xmin::text::bigint as _xmin FROM \"{}\".\"{}\" \ + WHERE xmin::text::bigint > $1 \ + ORDER BY xmin \ + LIMIT $2", + column_list, batch_reader.schema, batch_reader.table + ); + + let rows = self + .client + .query( + &query, + &[&(batch_reader.current_xmin as i64), &(batch_reader.batch_size as i64)], + ) + .await + .with_context(|| { + format!( + "Failed to read batch from {}.{}", + batch_reader.schema, batch_reader.table + ) + })?; + + if rows.is_empty() { + batch_reader.exhausted = true; + return Ok(None); + } + + // Update current_xmin to the max in this batch + let max_xmin = rows + .iter() + .map(|row| { + let xmin: i64 = row.get("_xmin"); + (xmin & 0xFFFFFFFF) as u32 + }) + .max() + .unwrap_or(batch_reader.current_xmin); + + // Mark as exhausted if we got fewer rows than batch_size + if rows.len() < batch_reader.batch_size { + batch_reader.exhausted = true; + } + + batch_reader.current_xmin = max_xmin; + + Ok(Some((rows, max_xmin))) + } + + /// Get the estimated row count for changes since a given xmin. + /// + /// This uses EXPLAIN to estimate without actually scanning the table. + pub async fn estimate_changes(&self, schema: &str, table: &str, since_xmin: u32) -> Result { + let query = format!( + "SELECT COUNT(*) FROM \"{}\".\"{}\" WHERE xmin::text::bigint > $1", + schema, table + ); + + let row = self + .client + .query_one(&query, &[&(since_xmin as i64)]) + .await + .with_context(|| format!("Failed to count changes in {}.{}", schema, table))?; + + let count: i64 = row.get(0); + Ok(count) + } + + /// Get list of all tables in a schema. + pub async fn list_tables(&self, schema: &str) -> Result> { + let rows = self + .client + .query( + "SELECT tablename FROM pg_tables WHERE schemaname = $1 ORDER BY tablename", + &[&schema], + ) + .await + .with_context(|| format!("Failed to list tables in schema {}", schema))?; + + Ok(rows.iter().map(|row| row.get(0)).collect()) + } + + /// Get column information for a table. + pub async fn get_columns(&self, schema: &str, table: &str) -> Result> { + let rows = self + .client + .query( + "SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_schema = $1 AND table_name = $2 + ORDER BY ordinal_position", + &[&schema, &table], + ) + .await + .with_context(|| format!("Failed to get columns for {}.{}", schema, table))?; + + Ok(rows + .iter() + .map(|row| ColumnInfo { + name: row.get(0), + data_type: row.get(1), + is_nullable: row.get::<_, String>(2) == "YES", + has_default: row.get::<_, Option>(3).is_some(), + }) + .collect()) + } + + /// Get primary key columns for a table. + pub async fn get_primary_key(&self, schema: &str, table: &str) -> Result> { + let rows = self + .client + .query( + "SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) + JOIN pg_class c ON c.oid = i.indrelid + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE i.indisprimary + AND n.nspname = $1 + AND c.relname = $2 + ORDER BY array_position(i.indkey, a.attnum)", + &[&schema, &table], + ) + .await + .with_context(|| format!("Failed to get primary key for {}.{}", schema, table))?; + + Ok(rows.iter().map(|row| row.get(0)).collect()) + } +} + +/// Batch reader state for iterating over large result sets. +pub struct BatchReader { + pub schema: String, + pub table: String, + pub columns: Vec, + pub current_xmin: u32, + pub batch_size: usize, + pub exhausted: bool, +} + +/// Information about a table column. +#[derive(Debug, Clone)] +pub struct ColumnInfo { + pub name: String, + pub data_type: String, + pub is_nullable: bool, + pub has_default: bool, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_batch_reader_initial_state() { + let reader = BatchReader { + schema: "public".to_string(), + table: "users".to_string(), + columns: vec!["id".to_string(), "name".to_string()], + current_xmin: 0, + batch_size: 1000, + exhausted: false, + }; + + assert_eq!(reader.schema, "public"); + assert_eq!(reader.table, "users"); + assert_eq!(reader.current_xmin, 0); + assert!(!reader.exhausted); + } + + #[test] + fn test_column_info() { + let col = ColumnInfo { + name: "id".to_string(), + data_type: "integer".to_string(), + is_nullable: false, + has_default: true, + }; + + assert_eq!(col.name, "id"); + assert!(!col.is_nullable); + assert!(col.has_default); + } +} diff --git a/src/xmin/reconciler.rs b/src/xmin/reconciler.rs new file mode 100644 index 0000000..f33314d --- /dev/null +++ b/src/xmin/reconciler.rs @@ -0,0 +1,278 @@ +// ABOUTME: Reconciler for xmin-based sync - detects deleted rows in source +// ABOUTME: Compares primary keys between source and target to find orphaned rows + +use anyhow::{Context, Result}; +use std::collections::HashSet; +use tokio_postgres::types::ToSql; +use tokio_postgres::Client; + +use super::writer::ChangeWriter; + +/// Reconciler detects rows that exist in target but not in source (deletions). +/// +/// Since xmin-based sync only sees modified rows, it cannot detect deletions. +/// The Reconciler performs periodic full-table primary key comparisons to find +/// rows that need to be deleted from the target. +pub struct Reconciler<'a> { + source_client: &'a Client, + target_client: &'a Client, +} + +impl<'a> Reconciler<'a> { + /// Create a new Reconciler with source and target database connections. + pub fn new(source_client: &'a Client, target_client: &'a Client) -> Self { + Self { + source_client, + target_client, + } + } + + /// Find rows that exist in target but not in source (orphaned rows). + /// + /// This performs a primary key comparison between source and target tables. + /// Returns the primary key values of rows that should be deleted from target. + /// + /// # Arguments + /// + /// * `schema` - Schema name + /// * `table` - Table name + /// * `primary_key_columns` - Primary key column names + /// + /// # Returns + /// + /// A vector of primary key value tuples for orphaned rows. + pub async fn find_orphaned_rows( + &self, + schema: &str, + table: &str, + primary_key_columns: &[String], + ) -> Result>> { + // Get all PKs from source + let source_pks = self + .get_all_primary_keys(self.source_client, schema, table, primary_key_columns) + .await + .context("Failed to get source primary keys")?; + + // Get all PKs from target + let target_pks = self + .get_all_primary_keys(self.target_client, schema, table, primary_key_columns) + .await + .context("Failed to get target primary keys")?; + + // Find PKs in target that don't exist in source + let source_set: HashSet> = source_pks.into_iter().collect(); + let orphaned: Vec> = target_pks + .into_iter() + .filter(|pk| !source_set.contains(pk)) + .collect(); + + tracing::info!( + "Found {} orphaned rows in {}.{} that need deletion", + orphaned.len(), + schema, + table + ); + + Ok(orphaned) + } + + /// Reconcile a table by deleting orphaned rows from target. + /// + /// This is a convenience method that finds orphaned rows and deletes them. + /// + /// # Returns + /// + /// The number of rows deleted from target. + pub async fn reconcile_table( + &self, + schema: &str, + table: &str, + primary_key_columns: &[String], + ) -> Result { + let orphaned = self + .find_orphaned_rows(schema, table, primary_key_columns) + .await?; + + if orphaned.is_empty() { + tracing::info!("No orphaned rows found in {}.{}", schema, table); + return Ok(0); + } + + // Convert string PKs to ToSql values + let pk_values: Vec>> = orphaned + .into_iter() + .map(|pk| { + pk.into_iter() + .map(|v| Box::new(v) as Box) + .collect() + }) + .collect(); + + // Delete orphaned rows + let writer = ChangeWriter::new(self.target_client); + let deleted = writer + .delete_rows(schema, table, primary_key_columns, pk_values) + .await?; + + tracing::info!("Deleted {} orphaned rows from {}.{}", deleted, schema, table); + + Ok(deleted) + } + + /// Get all primary key values from a table. + async fn get_all_primary_keys( + &self, + client: &Client, + schema: &str, + table: &str, + primary_key_columns: &[String], + ) -> Result>> { + let pk_cols: Vec = primary_key_columns + .iter() + .map(|c| format!("\"{}\"::text", c)) + .collect(); + + let query = format!( + "SELECT {} FROM \"{}\".\"{}\" ORDER BY {}", + pk_cols.join(", "), + schema, + table, + primary_key_columns + .iter() + .map(|c| format!("\"{}\"", c)) + .collect::>() + .join(", ") + ); + + let rows = client + .query(&query, &[]) + .await + .with_context(|| format!("Failed to get primary keys from {}.{}", schema, table))?; + + let pks: Vec> = rows + .iter() + .map(|row| { + (0..primary_key_columns.len()) + .map(|i| row.get::<_, String>(i)) + .collect() + }) + .collect(); + + Ok(pks) + } + + /// Get count of rows in source and target for comparison. + pub async fn get_row_counts(&self, schema: &str, table: &str) -> Result<(i64, i64)> { + let query = format!("SELECT COUNT(*) FROM \"{}\".\"{}\"", schema, table); + + let source_row = self + .source_client + .query_one(&query, &[]) + .await + .context("Failed to get source row count")?; + let source_count: i64 = source_row.get(0); + + let target_row = self + .target_client + .query_one(&query, &[]) + .await + .context("Failed to get target row count")?; + let target_count: i64 = target_row.get(0); + + Ok((source_count, target_count)) + } + + /// Check if a table exists in the target database. + pub async fn table_exists_in_target(&self, schema: &str, table: &str) -> Result { + let query = "SELECT EXISTS ( + SELECT 1 FROM information_schema.tables + WHERE table_schema = $1 AND table_name = $2 + )"; + + let row = self + .target_client + .query_one(query, &[&schema, &table]) + .await + .context("Failed to check if table exists")?; + + Ok(row.get(0)) + } +} + +/// Configuration for reconciliation behavior. +#[derive(Debug, Clone)] +pub struct ReconcileConfig { + /// Whether to actually delete orphaned rows (false = dry run) + pub delete_orphans: bool, + /// Maximum number of orphans to delete in one batch + pub max_deletes: Option, + /// Tables to skip during reconciliation + pub skip_tables: Vec, +} + +impl Default for ReconcileConfig { + fn default() -> Self { + Self { + delete_orphans: true, + max_deletes: None, + skip_tables: Vec::new(), + } + } +} + +/// Result of a reconciliation operation. +#[derive(Debug, Clone)] +pub struct ReconcileResult { + pub schema: String, + pub table: String, + pub source_count: i64, + pub target_count: i64, + pub orphaned_count: usize, + pub deleted_count: u64, +} + +impl ReconcileResult { + /// Check if the table is in sync (same row count, no orphans). + pub fn is_in_sync(&self) -> bool { + self.source_count == self.target_count && self.orphaned_count == 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_reconcile_config_default() { + let config = ReconcileConfig::default(); + assert!(config.delete_orphans); + assert!(config.max_deletes.is_none()); + assert!(config.skip_tables.is_empty()); + } + + #[test] + fn test_reconcile_result_in_sync() { + let result = ReconcileResult { + schema: "public".to_string(), + table: "users".to_string(), + source_count: 100, + target_count: 100, + orphaned_count: 0, + deleted_count: 0, + }; + assert!(result.is_in_sync()); + } + + #[test] + fn test_reconcile_result_not_in_sync() { + let result = ReconcileResult { + schema: "public".to_string(), + table: "users".to_string(), + source_count: 100, + target_count: 105, + orphaned_count: 5, + deleted_count: 0, + }; + assert!(!result.is_in_sync()); + } +} diff --git a/src/xmin/state.rs b/src/xmin/state.rs new file mode 100644 index 0000000..3e94909 --- /dev/null +++ b/src/xmin/state.rs @@ -0,0 +1,253 @@ +// ABOUTME: SyncState for xmin-based sync - tracks sync progress per table +// ABOUTME: Persists high-water mark xmin values to enable incremental syncs + +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::Path; +use tokio::fs; + +/// Sync state for a single table, tracking the last synced xmin value. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TableSyncState { + /// Schema name (e.g., "public") + pub schema: String, + /// Table name + pub table: String, + /// Last successfully synced xmin value (high-water mark) + /// Rows with xmin > this value need to be synced + pub last_xmin: u32, + /// Timestamp of last successful sync + pub last_sync_at: chrono::DateTime, + /// Number of rows synced in last batch + pub last_row_count: u64, +} + +impl TableSyncState { + /// Create a new TableSyncState with initial xmin of 0 (sync everything) + pub fn new(schema: &str, table: &str) -> Self { + Self { + schema: schema.to_string(), + table: table.to_string(), + last_xmin: 0, + last_sync_at: chrono::Utc::now(), + last_row_count: 0, + } + } + + /// Update state after a successful sync + pub fn update(&mut self, new_xmin: u32, row_count: u64) { + self.last_xmin = new_xmin; + self.last_sync_at = chrono::Utc::now(); + self.last_row_count = row_count; + } + + /// Get the qualified table name (schema.table) + pub fn qualified_name(&self) -> String { + format!("{}.{}", self.schema, self.table) + } +} + +/// Overall sync state for a database, containing state for all tracked tables. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SyncState { + /// Source database URL (sanitized - no password) + pub source_url: String, + /// Target database URL (sanitized - no password) + pub target_url: String, + /// Per-table sync states, keyed by "schema.table" + pub tables: HashMap, + /// Version of the state format for future migrations + pub version: u32, + /// When this state was created + pub created_at: chrono::DateTime, + /// When this state was last modified + pub updated_at: chrono::DateTime, +} + +impl SyncState { + /// Create a new empty SyncState + pub fn new(source_url: &str, target_url: &str) -> Self { + let now = chrono::Utc::now(); + Self { + source_url: sanitize_url(source_url), + target_url: sanitize_url(target_url), + tables: HashMap::new(), + version: 1, + created_at: now, + updated_at: now, + } + } + + /// Get or create state for a table + pub fn get_or_create_table(&mut self, schema: &str, table: &str) -> &mut TableSyncState { + let key = format!("{}.{}", schema, table); + self.tables + .entry(key) + .or_insert_with(|| TableSyncState::new(schema, table)) + } + + /// Get state for a table if it exists + pub fn get_table(&self, schema: &str, table: &str) -> Option<&TableSyncState> { + let key = format!("{}.{}", schema, table); + self.tables.get(&key) + } + + /// Update state for a table after successful sync + pub fn update_table(&mut self, schema: &str, table: &str, new_xmin: u32, row_count: u64) { + let state = self.get_or_create_table(schema, table); + state.update(new_xmin, row_count); + self.updated_at = chrono::Utc::now(); + } + + /// Remove state for a table (e.g., if table was dropped) + pub fn remove_table(&mut self, schema: &str, table: &str) -> Option { + let key = format!("{}.{}", schema, table); + let removed = self.tables.remove(&key); + if removed.is_some() { + self.updated_at = chrono::Utc::now(); + } + removed + } + + /// Get all table names being tracked + pub fn tracked_tables(&self) -> Vec<&str> { + self.tables.keys().map(|s| s.as_str()).collect() + } + + /// Load state from a JSON file + pub async fn load(path: &Path) -> Result { + let contents = fs::read_to_string(path) + .await + .with_context(|| format!("Failed to read sync state from {:?}", path))?; + let state: SyncState = serde_json::from_str(&contents) + .with_context(|| format!("Failed to parse sync state from {:?}", path))?; + Ok(state) + } + + /// Save state to a JSON file + pub async fn save(&self, path: &Path) -> Result<()> { + // Ensure parent directory exists + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .await + .with_context(|| format!("Failed to create directory {:?}", parent))?; + } + + let contents = serde_json::to_string_pretty(self) + .context("Failed to serialize sync state")?; + fs::write(path, contents) + .await + .with_context(|| format!("Failed to write sync state to {:?}", path))?; + Ok(()) + } + + /// Get the default state file path for the current directory + pub fn default_path() -> std::path::PathBuf { + std::path::PathBuf::from(".seren-replicator/xmin-sync-state.json") + } +} + +/// Sanitize a database URL by removing the password component +fn sanitize_url(url: &str) -> String { + // Try to parse as URL and remove password + if let Ok(mut parsed) = url::Url::parse(url) { + if parsed.password().is_some() { + let _ = parsed.set_password(Some("***")); + } + parsed.to_string() + } else { + // If not a valid URL, return as-is (might be a file path for SQLite) + url.to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_table_sync_state_new() { + let state = TableSyncState::new("public", "users"); + assert_eq!(state.schema, "public"); + assert_eq!(state.table, "users"); + assert_eq!(state.last_xmin, 0); + assert_eq!(state.last_row_count, 0); + } + + #[test] + fn test_table_sync_state_update() { + let mut state = TableSyncState::new("public", "users"); + state.update(12345, 100); + assert_eq!(state.last_xmin, 12345); + assert_eq!(state.last_row_count, 100); + } + + #[test] + fn test_table_sync_state_qualified_name() { + let state = TableSyncState::new("myschema", "mytable"); + assert_eq!(state.qualified_name(), "myschema.mytable"); + } + + #[test] + fn test_sync_state_new() { + let state = SyncState::new( + "postgresql://user:pass@localhost/db", + "postgresql://user:pass@remote/db", + ); + assert!(state.tables.is_empty()); + assert_eq!(state.version, 1); + // Passwords should be sanitized + assert!(state.source_url.contains("***")); + assert!(state.target_url.contains("***")); + } + + #[test] + fn test_sync_state_get_or_create() { + let mut state = SyncState::new("source", "target"); + + // First call creates + let table_state = state.get_or_create_table("public", "users"); + assert_eq!(table_state.last_xmin, 0); + + // Update it + table_state.update(100, 50); + + // Second call retrieves existing + let table_state = state.get_or_create_table("public", "users"); + assert_eq!(table_state.last_xmin, 100); + } + + #[test] + fn test_sync_state_update_table() { + let mut state = SyncState::new("source", "target"); + state.update_table("public", "users", 500, 200); + + let table_state = state.get_table("public", "users").unwrap(); + assert_eq!(table_state.last_xmin, 500); + assert_eq!(table_state.last_row_count, 200); + } + + #[test] + fn test_sync_state_remove_table() { + let mut state = SyncState::new("source", "target"); + state.update_table("public", "users", 100, 10); + + let removed = state.remove_table("public", "users"); + assert!(removed.is_some()); + assert!(state.get_table("public", "users").is_none()); + } + + #[test] + fn test_sanitize_url() { + assert_eq!( + sanitize_url("postgresql://user:secret@localhost/db"), + "postgresql://user:***@localhost/db" + ); + assert_eq!( + sanitize_url("postgresql://user@localhost/db"), + "postgresql://user@localhost/db" + ); + assert_eq!(sanitize_url("/path/to/db.sqlite"), "/path/to/db.sqlite"); + } +} diff --git a/src/xmin/writer.rs b/src/xmin/writer.rs new file mode 100644 index 0000000..8d1033e --- /dev/null +++ b/src/xmin/writer.rs @@ -0,0 +1,535 @@ +// ABOUTME: ChangeWriter for xmin-based sync - applies changes to target PostgreSQL +// ABOUTME: Uses INSERT ... ON CONFLICT DO UPDATE for efficient upserts + +use anyhow::{Context, Result}; +use tokio_postgres::types::ToSql; +use tokio_postgres::{Client, Row}; + +/// Writes changes to the target PostgreSQL database using upsert operations. +/// +/// The ChangeWriter handles batched upserts within transactions for efficiency +/// and atomicity. It dynamically builds INSERT ... ON CONFLICT DO UPDATE queries +/// based on table schema. +pub struct ChangeWriter<'a> { + client: &'a Client, +} + +impl<'a> ChangeWriter<'a> { + /// Create a new ChangeWriter for the given PostgreSQL client connection. + pub fn new(client: &'a Client) -> Self { + Self { client } + } + + /// Get a reference to the underlying client. + /// + /// Useful for callers that need to perform additional queries. + pub fn client(&self) -> &Client { + self.client + } + + /// Apply a batch of rows to a table using upsert (INSERT ... ON CONFLICT DO UPDATE). + /// + /// Uses batching internally to stay within PostgreSQL's parameter limits. + /// Each batch is executed as a separate query (PostgreSQL auto-commits). + /// + /// # Arguments + /// + /// * `schema` - The schema name (e.g., "public") + /// * `table` - The table name + /// * `primary_key_columns` - Column names that form the primary key + /// * `all_columns` - All column names in the order they appear in `rows` + /// * `rows` - The rows to upsert, each row is a vector of values + /// + /// # Returns + /// + /// The number of rows affected. + pub async fn apply_batch( + &self, + schema: &str, + table: &str, + primary_key_columns: &[String], + all_columns: &[String], + rows: Vec>>, + ) -> Result { + if rows.is_empty() { + return Ok(0); + } + + // PostgreSQL has a limit of ~65535 parameters per query + // Calculate batch size based on number of columns + let params_per_row = all_columns.len(); + let max_params = 65000; // Leave some margin + let batch_size = std::cmp::max(1, max_params / params_per_row); + + let mut total_affected = 0u64; + + for chunk in rows.chunks(batch_size) { + let affected = self + .execute_upsert_batch(schema, table, primary_key_columns, all_columns, chunk) + .await?; + total_affected += affected; + } + + Ok(total_affected) + } + + /// Execute a single batch of upserts. + async fn execute_upsert_batch( + &self, + schema: &str, + table: &str, + primary_key_columns: &[String], + all_columns: &[String], + rows: &[Vec>], + ) -> Result { + if rows.is_empty() { + return Ok(0); + } + + let query = build_upsert_query(schema, table, primary_key_columns, all_columns, rows.len()); + + // Flatten all row values into a single params vector + let params: Vec<&(dyn ToSql + Sync)> = rows + .iter() + .flat_map(|row| row.iter().map(|v| v.as_ref() as &(dyn ToSql + Sync))) + .collect(); + + let affected = self + .client + .execute(&query, ¶ms) + .await + .with_context(|| format!("Failed to upsert batch into {}.{}", schema, table))?; + + Ok(affected) + } + + /// Apply a single row using upsert. + /// + /// For single rows, this is more efficient than creating a batch. + pub async fn apply_row( + &self, + schema: &str, + table: &str, + primary_key_columns: &[String], + all_columns: &[String], + values: Vec>, + ) -> Result { + let query = build_upsert_query(schema, table, primary_key_columns, all_columns, 1); + + let params: Vec<&(dyn ToSql + Sync)> = values + .iter() + .map(|v| v.as_ref() as &(dyn ToSql + Sync)) + .collect(); + + let affected = self + .client + .execute(&query, ¶ms) + .await + .with_context(|| format!("Failed to upsert row into {}.{}", schema, table))?; + + Ok(affected) + } + + /// Delete rows by primary key values. + /// + /// Used by the reconciler to remove rows that no longer exist in source. + /// Executes deletes in batches to stay within PostgreSQL parameter limits. + pub async fn delete_rows( + &self, + schema: &str, + table: &str, + primary_key_columns: &[String], + pk_values: Vec>>, + ) -> Result { + if pk_values.is_empty() { + return Ok(0); + } + + let mut total_deleted = 0u64; + + // Delete in batches + let batch_size = 1000; + for chunk in pk_values.chunks(batch_size) { + let deleted = self + .execute_delete_batch(schema, table, primary_key_columns, chunk) + .await?; + total_deleted += deleted; + } + + Ok(total_deleted) + } + + /// Execute a batch delete. + async fn execute_delete_batch( + &self, + schema: &str, + table: &str, + primary_key_columns: &[String], + pk_values: &[Vec>], + ) -> Result { + if pk_values.is_empty() { + return Ok(0); + } + + let query = build_delete_query(schema, table, primary_key_columns, pk_values.len()); + + let params: Vec<&(dyn ToSql + Sync)> = pk_values + .iter() + .flat_map(|row| row.iter().map(|v| v.as_ref() as &(dyn ToSql + Sync))) + .collect(); + + let deleted = self + .client + .execute(&query, ¶ms) + .await + .with_context(|| format!("Failed to delete rows from {}.{}", schema, table))?; + + Ok(deleted) + } +} + +/// Build an upsert query for the given table schema and batch size. +/// +/// Generates a query like: +/// ```sql +/// INSERT INTO "schema"."table" ("col1", "col2", "col3") +/// VALUES ($1, $2, $3), ($4, $5, $6), ... +/// ON CONFLICT ("pk_col") DO UPDATE SET +/// "col2" = EXCLUDED."col2", +/// "col3" = EXCLUDED."col3" +/// ``` +fn build_upsert_query( + schema: &str, + table: &str, + primary_key_columns: &[String], + all_columns: &[String], + num_rows: usize, +) -> String { + // Quote identifiers to handle reserved words and special characters + let quoted_columns: Vec = all_columns.iter().map(|c| format!("\"{}\"", c)).collect(); + + let quoted_pk_columns: Vec = primary_key_columns + .iter() + .map(|c| format!("\"{}\"", c)) + .collect(); + + // Build VALUES placeholders: ($1, $2, $3), ($4, $5, $6), ... + let num_cols = all_columns.len(); + let value_rows: Vec = (0..num_rows) + .map(|row_idx| { + let placeholders: Vec = (0..num_cols) + .map(|col_idx| format!("${}", row_idx * num_cols + col_idx + 1)) + .collect(); + format!("({})", placeholders.join(", ")) + }) + .collect(); + + // Build UPDATE SET clause for non-PK columns + let update_columns: Vec = all_columns + .iter() + .filter(|c| !primary_key_columns.contains(c)) + .map(|c| format!("\"{}\" = EXCLUDED.\"{}\"", c, c)) + .collect(); + + let update_clause = if update_columns.is_empty() { + // All columns are PKs - use DO NOTHING + "DO NOTHING".to_string() + } else { + format!("DO UPDATE SET {}", update_columns.join(", ")) + }; + + format!( + "INSERT INTO \"{}\".\"{}\" ({}) VALUES {} ON CONFLICT ({}) {}", + schema, + table, + quoted_columns.join(", "), + value_rows.join(", "), + quoted_pk_columns.join(", "), + update_clause + ) +} + +/// Build a delete query for multiple rows by primary key. +/// +/// For single-column PK: +/// ```sql +/// DELETE FROM "schema"."table" WHERE "id" IN ($1, $2, $3, ...) +/// ``` +/// +/// For composite PK: +/// ```sql +/// DELETE FROM "schema"."table" WHERE ("pk1", "pk2") IN (($1, $2), ($3, $4), ...) +/// ``` +fn build_delete_query( + schema: &str, + table: &str, + primary_key_columns: &[String], + num_rows: usize, +) -> String { + let num_pk_cols = primary_key_columns.len(); + + if num_pk_cols == 1 { + // Simple case: single-column primary key + let pk_col = format!("\"{}\"", primary_key_columns[0]); + let placeholders: Vec = (1..=num_rows).map(|i| format!("${}", i)).collect(); + + format!( + "DELETE FROM \"{}\".\"{}\" WHERE {} IN ({})", + schema, + table, + pk_col, + placeholders.join(", ") + ) + } else { + // Composite primary key + let pk_cols: Vec = primary_key_columns + .iter() + .map(|c| format!("\"{}\"", c)) + .collect(); + + let value_tuples: Vec = (0..num_rows) + .map(|row_idx| { + let placeholders: Vec = (0..num_pk_cols) + .map(|col_idx| format!("${}", row_idx * num_pk_cols + col_idx + 1)) + .collect(); + format!("({})", placeholders.join(", ")) + }) + .collect(); + + format!( + "DELETE FROM \"{}\".\"{}\" WHERE ({}) IN ({})", + schema, + table, + pk_cols.join(", "), + value_tuples.join(", ") + ) + } +} + +/// Extract column metadata from a PostgreSQL table. +/// +/// Returns (column_name, data_type) pairs for all columns in the table. +pub async fn get_table_columns( + client: &Client, + schema: &str, + table: &str, +) -> Result> { + let rows = client + .query( + "SELECT column_name, data_type + FROM information_schema.columns + WHERE table_schema = $1 AND table_name = $2 + ORDER BY ordinal_position", + &[&schema, &table], + ) + .await + .with_context(|| format!("Failed to get columns for {}.{}", schema, table))?; + + Ok(rows + .iter() + .map(|row| { + let name: String = row.get(0); + let dtype: String = row.get(1); + (name, dtype) + }) + .collect()) +} + +/// Get primary key columns for a table. +/// +/// Returns the column names that form the primary key constraint. +pub async fn get_primary_key_columns( + client: &Client, + schema: &str, + table: &str, +) -> Result> { + let rows = client + .query( + "SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) + JOIN pg_class c ON c.oid = i.indrelid + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE i.indisprimary + AND n.nspname = $1 + AND c.relname = $2 + ORDER BY array_position(i.indkey, a.attnum)", + &[&schema, &table], + ) + .await + .with_context(|| format!("Failed to get primary key for {}.{}", schema, table))?; + + Ok(rows.iter().map(|row| row.get(0)).collect()) +} + +/// Convert a tokio_postgres Row to a vector of boxed ToSql values. +/// +/// This is a helper for extracting values from source rows to pass to ChangeWriter. +/// The caller must know the column types to extract values correctly. +pub fn row_to_values( + row: &Row, + column_types: &[(String, String)], +) -> Vec> { + column_types + .iter() + .enumerate() + .map(|(idx, (_name, dtype))| -> Box { + // Handle common PostgreSQL types + match dtype.as_str() { + "integer" | "int4" => { + let val: Option = row.get(idx); + Box::new(val) + } + "bigint" | "int8" => { + let val: Option = row.get(idx); + Box::new(val) + } + "smallint" | "int2" => { + let val: Option = row.get(idx); + Box::new(val) + } + "text" | "varchar" | "character varying" | "char" | "character" | "name" => { + let val: Option = row.get(idx); + Box::new(val) + } + "boolean" | "bool" => { + let val: Option = row.get(idx); + Box::new(val) + } + "real" | "float4" => { + let val: Option = row.get(idx); + Box::new(val) + } + "double precision" | "float8" => { + let val: Option = row.get(idx); + Box::new(val) + } + "uuid" => { + let val: Option = row.get(idx); + Box::new(val) + } + "timestamp without time zone" | "timestamp" => { + let val: Option = row.get(idx); + Box::new(val) + } + "timestamp with time zone" | "timestamptz" => { + let val: Option> = row.get(idx); + Box::new(val) + } + "date" => { + let val: Option = row.get(idx); + Box::new(val) + } + "json" | "jsonb" => { + let val: Option = row.get(idx); + Box::new(val) + } + "bytea" => { + let val: Option> = row.get(idx); + Box::new(val) + } + "numeric" | "decimal" => { + // Fall back to string representation + let val: Option = row.try_get::<_, String>(idx).ok(); + Box::new(val) + } + _ => { + // For unknown types, try to get as string + let val: Option = row.try_get::<_, String>(idx).ok(); + Box::new(val) + } + } + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_build_upsert_query_single_row() { + let query = build_upsert_query( + "public", + "users", + &["id".to_string()], + &["id".to_string(), "name".to_string(), "email".to_string()], + 1, + ); + + assert!(query.contains("INSERT INTO \"public\".\"users\"")); + assert!(query.contains("(\"id\", \"name\", \"email\")")); + assert!(query.contains("VALUES ($1, $2, $3)")); + assert!(query.contains("ON CONFLICT (\"id\")")); + assert!(query.contains("DO UPDATE SET")); + assert!(query.contains("\"name\" = EXCLUDED.\"name\"")); + assert!(query.contains("\"email\" = EXCLUDED.\"email\"")); + } + + #[test] + fn test_build_upsert_query_multiple_rows() { + let query = build_upsert_query( + "public", + "users", + &["id".to_string()], + &["id".to_string(), "name".to_string()], + 3, + ); + + assert!(query.contains("($1, $2), ($3, $4), ($5, $6)")); + } + + #[test] + fn test_build_upsert_query_composite_pk() { + let query = build_upsert_query( + "public", + "order_items", + &["order_id".to_string(), "item_id".to_string()], + &[ + "order_id".to_string(), + "item_id".to_string(), + "quantity".to_string(), + ], + 1, + ); + + assert!(query.contains("ON CONFLICT (\"order_id\", \"item_id\")")); + assert!(query.contains("\"quantity\" = EXCLUDED.\"quantity\"")); + } + + #[test] + fn test_build_upsert_query_all_pk_columns() { + // When all columns are PK columns, should use DO NOTHING + let query = build_upsert_query( + "public", + "tags", + &["id".to_string()], + &["id".to_string()], + 1, + ); + + assert!(query.contains("DO NOTHING")); + assert!(!query.contains("DO UPDATE SET")); + } + + #[test] + fn test_build_delete_query_single_pk() { + let query = build_delete_query("public", "users", &["id".to_string()], 3); + + assert!(query.contains("DELETE FROM \"public\".\"users\"")); + assert!(query.contains("WHERE \"id\" IN ($1, $2, $3)")); + } + + #[test] + fn test_build_delete_query_composite_pk() { + let query = build_delete_query( + "public", + "order_items", + &["order_id".to_string(), "item_id".to_string()], + 2, + ); + + assert!(query.contains("WHERE (\"order_id\", \"item_id\") IN")); + assert!(query.contains("($1, $2), ($3, $4)")); + } +}