Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/commands/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,10 @@ pub async fn init(
.with_context(|| format!("Invalid database name: '{}'", db_info.name))?;

// Try to create database atomically (avoids TOCTOU vulnerability)
let create_query = format!("CREATE DATABASE \"{}\"", db_info.name);
let create_query = format!(
"CREATE DATABASE {}",
crate::utils::quote_ident(&db_info.name)
);
match target_client.execute(&create_query, &[]).await {
Ok(_) => {
tracing::info!(" Created database '{}'", db_info.name);
Expand Down Expand Up @@ -372,8 +375,10 @@ pub async fn init(
drop_database_if_exists(&target_client, &db_info.name).await?;

// Recreate the database
let create_query =
format!("CREATE DATABASE \"{}\"", db_info.name);
let create_query = format!(
"CREATE DATABASE {}",
crate::utils::quote_ident(&db_info.name)
);
target_client
.execute(&create_query, &[])
.await
Expand Down Expand Up @@ -666,7 +671,10 @@ async fn drop_database_if_exists(target_conn: &Client, db_name: &str) -> Result<
target_conn.execute(terminate_query, &[&db_name]).await?;

// Drop the database
let drop_query = format!("DROP DATABASE IF EXISTS \"{}\"", db_name);
let drop_query = format!(
"DROP DATABASE IF EXISTS {}",
crate::utils::quote_ident(db_name)
);
target_conn
.execute(&drop_query, &[])
.await
Expand Down
21 changes: 19 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ use database_replicator::commands;
#[command(about = "Universal database-to-PostgreSQL replication CLI", long_about = None)]
#[command(version)]
struct Cli {
/// Allow self-signed TLS certificates (insecure - use only for testing)
#[arg(
long = "allow-self-signed-certs",
global = true,
default_value_t = false
)]
allow_self_signed_certs: bool,
#[command(subcommand)]
command: Commands,
}
Expand Down Expand Up @@ -181,6 +188,9 @@ async fn main() -> anyhow::Result<()> {

let cli = Cli::parse();

// Initialize TLS policy using thread-safe OnceLock
database_replicator::postgres::connection::init_tls_policy(cli.allow_self_signed_certs);

match cli.command {
Commands::Validate {
source,
Expand Down Expand Up @@ -402,7 +412,7 @@ async fn init_remote(
drop_existing: bool,
no_sync: bool,
seren_api: String,
_job_timeout: u64,
job_timeout: u64,
) -> anyhow::Result<()> {
use database_replicator::migration;
use database_replicator::postgres;
Expand Down Expand Up @@ -466,6 +476,8 @@ async fn init_remote(
} else {
Some(FilterSpec {
include_databases,
exclude_databases,
include_tables,
exclude_tables,
})
};
Expand All @@ -481,7 +493,12 @@ async fn init_remote(
"estimated_size_bytes".to_string(),
serde_json::Value::Number(serde_json::Number::from(estimated_size_bytes)),
);
// Note: "yes" and "job_timeout" are client-side only options, not sent to server
// Optional timeout hint for remote orchestrator
options.insert(
"job_timeout_seconds".to_string(),
serde_json::Value::Number(serde_json::Number::from(job_timeout as i64)),
);
// Note: "yes" is client-side only, not sent to server

let job_spec = JobSpec {
version: "1.0".to_string(),
Expand Down
3 changes: 3 additions & 0 deletions src/migration/dump.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub async fn dump_globals(source_url: &str, output_path: &str) -> Result<()> {
Duration::from_secs(1), // Start with 1 second delay
"pg_dumpall (dump globals)",
)
.await
.context(
"pg_dumpall failed to dump global objects.\n\
\n\
Expand Down Expand Up @@ -172,6 +173,7 @@ pub async fn dump_schema(
Duration::from_secs(1), // Start with 1 second delay
"pg_dump (dump schema)",
)
.await
.with_context(|| {
format!(
"pg_dump failed to dump schema for database '{}'.\n\
Expand Down Expand Up @@ -299,6 +301,7 @@ pub async fn dump_data(
Duration::from_secs(1), // Start with 1 second delay
"pg_dump (dump data)",
)
.await
.with_context(|| {
format!(
"pg_dump failed to dump data for database '{}'.\n\
Expand Down
5 changes: 4 additions & 1 deletion src/migration/restore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ pub async fn restore_globals(target_url: &str, input_path: &str) -> Result<()> {
3, // Max 3 retries
Duration::from_secs(1), // Start with 1 second delay
"psql (restore globals)",
);
)
.await;

// Handle result - don't fail on warnings for global objects
match result {
Expand Down Expand Up @@ -136,6 +137,7 @@ pub async fn restore_schema(target_url: &str, input_path: &str) -> Result<()> {
Duration::from_secs(1), // Start with 1 second delay
"psql (restore schema)",
)
.await
.context(
"Schema restoration failed.\n\
\n\
Expand Down Expand Up @@ -228,6 +230,7 @@ pub async fn restore_data(target_url: &str, input_path: &str) -> Result<()> {
Duration::from_secs(1), // Start with 1 second delay
"pg_restore (restore data)",
)
.await
.context(
"Data restoration failed.\n\
\n\
Expand Down
6 changes: 1 addition & 5 deletions src/mysql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,7 @@ pub fn validate_mysql_url(connection_string: &str) -> Result<String> {
}

if !connection_string.starts_with("mysql://") {
bail!(
"Invalid MySQL connection string '{}'. \
Must start with 'mysql://'",
connection_string
);
bail!("Invalid MySQL connection string. Must start with 'mysql://'");
}

tracing::debug!("Validated MySQL connection string");
Expand Down
12 changes: 10 additions & 2 deletions src/mysql/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ pub async fn get_table_row_count(
tracing::debug!("Getting row count for table '{}.{}'", db_name, table_name);

// Use backticks for identifiers to allow reserved words
let query = format!("SELECT COUNT(*) FROM `{}`.`{}`", db_name, table_name);
let query = format!(
"SELECT COUNT(*) FROM {}.{}",
crate::utils::quote_mysql_ident(db_name),
crate::utils::quote_mysql_ident(table_name)
);

let count: Option<u64> = conn
.query_first(&query)
Expand Down Expand Up @@ -137,7 +141,11 @@ pub async fn read_table_data(conn: &mut Conn, db_name: &str, table_name: &str) -
tracing::info!("Reading all rows from table '{}.{}'", db_name, table_name);

// Use backticks for identifiers
let query = format!("SELECT * FROM `{}`.`{}`", db_name, table_name);
let query = format!(
"SELECT * FROM {}.{}",
crate::utils::quote_mysql_ident(db_name),
crate::utils::quote_mysql_ident(table_name)
);

let rows: Vec<Row> = conn
.query(&query)
Expand Down
32 changes: 28 additions & 4 deletions src/postgres/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,28 @@ use crate::utils;
use anyhow::{Context, Result};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use std::sync::OnceLock;
use std::time::Duration;
use tokio_postgres::Client;

/// Thread-safe storage for TLS configuration set at startup
static ALLOW_SELF_SIGNED_CERTS: OnceLock<bool> = OnceLock::new();

/// Initialize the TLS certificate policy (call once at startup)
///
/// This must be called before any database connections are made.
/// It is thread-safe and will only set the value once.
///
/// # Arguments
///
/// * `allow` - If true, accept self-signed/invalid TLS certificates (insecure)
pub fn init_tls_policy(allow: bool) {
let _ = ALLOW_SELF_SIGNED_CERTS.set(allow);
if allow {
tracing::warn!("TLS policy: Allowing self-signed/invalid certificates (insecure)");
}
}

/// Add TCP keepalive parameters to a PostgreSQL connection string
///
/// Automatically adds keepalive parameters to prevent idle connection timeouts
Expand Down Expand Up @@ -130,10 +149,15 @@ pub async fn connect(connection_string: &str) -> Result<Client> {
)?;

// Set up TLS connector for cloud connections
// TEMPORARY: Accept invalid certs to debug TLS issues
// TODO: Remove this once we identify the certificate validation issue
let tls_connector = TlsConnector::builder()
.danger_accept_invalid_certs(true)
// By default, require valid certificates. Allow opt-in via init_tls_policy() called at startup.
let allow_self_signed = ALLOW_SELF_SIGNED_CERTS.get().copied().unwrap_or(false);

let mut tls_builder = TlsConnector::builder();
if allow_self_signed {
tls_builder.danger_accept_invalid_certs(true);
}

let tls_connector = tls_builder
.build()
.context("Failed to build TLS connector")?;
let tls = MakeTlsConnector::new(tls_connector);
Expand Down
2 changes: 2 additions & 0 deletions src/remote/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub struct JobSpec {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FilterSpec {
pub include_databases: Option<Vec<String>>,
pub exclude_databases: Option<Vec<String>>,
pub include_tables: Option<Vec<String>>,
pub exclude_tables: Option<Vec<String>>,
}

Expand Down
132 changes: 72 additions & 60 deletions src/replication/monitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,41 +34,47 @@ pub async fn get_replication_lag(
client: &Client,
subscription_name: Option<&str>,
) -> Result<Vec<SourceReplicationStats>> {
let query = if let Some(sub_name) = subscription_name {
format!(
"SELECT
application_name,
state,
sent_lsn::text,
write_lsn::text,
flush_lsn::text,
replay_lsn::text,
EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
FROM pg_stat_replication
WHERE application_name = '{}'",
sub_name
)
} else {
"SELECT
application_name,
state,
sent_lsn::text,
write_lsn::text,
flush_lsn::text,
replay_lsn::text,
EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
FROM pg_stat_replication"
.to_string()
};
if let Some(name) = subscription_name {
crate::utils::validate_postgres_identifier(name).context("Invalid subscription name")?;
}

let rows = client
.query(&query, &[])
.await
.context("Failed to query replication statistics")?;
let rows = if let Some(sub_name) = subscription_name {
client
.query(
"SELECT
application_name,
state,
sent_lsn::text,
write_lsn::text,
flush_lsn::text,
replay_lsn::text,
EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
FROM pg_stat_replication
WHERE application_name = $1",
&[&sub_name],
)
.await
} else {
client
.query(
"SELECT
application_name,
state,
sent_lsn::text,
write_lsn::text,
flush_lsn::text,
replay_lsn::text,
EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
FROM pg_stat_replication",
&[],
)
.await
}
.context("Failed to query replication statistics")?;

let mut stats = Vec::new();
for row in rows {
Expand All @@ -94,33 +100,39 @@ pub async fn get_subscription_status(
client: &Client,
subscription_name: Option<&str>,
) -> Result<Vec<SubscriptionStats>> {
let query = if let Some(sub_name) = subscription_name {
format!(
"SELECT
subname,
pid,
received_lsn::text,
latest_end_lsn::text,
srsubstate
FROM pg_stat_subscription
WHERE subname = '{}'",
sub_name
)
} else {
"SELECT
subname,
pid,
received_lsn::text,
latest_end_lsn::text,
srsubstate
FROM pg_stat_subscription"
.to_string()
};
if let Some(name) = subscription_name {
crate::utils::validate_postgres_identifier(name).context("Invalid subscription name")?;
}

let rows = client
.query(&query, &[])
.await
.context("Failed to query subscription statistics")?;
let rows = if let Some(sub_name) = subscription_name {
client
.query(
"SELECT
subname,
pid,
received_lsn::text,
latest_end_lsn::text,
srsubstate
FROM pg_stat_subscription
WHERE subname = $1",
&[&sub_name],
)
.await
} else {
client
.query(
"SELECT
subname,
pid,
received_lsn::text,
latest_end_lsn::text,
srsubstate
FROM pg_stat_subscription",
&[],
)
.await
}
.context("Failed to query subscription statistics")?;

let mut stats = Vec::new();
for row in rows {
Expand Down
Loading