Skip to content

Commit 7ba7c7c

Browse files
fix: harden remote filters and async retries
1 parent 753a26a commit 7ba7c7c

9 files changed

Lines changed: 149 additions & 77 deletions

File tree

src/main.rs

Lines changed: 30 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,9 @@ use database_replicator::commands;
99
#[command(about = "Universal database-to-PostgreSQL replication CLI", long_about = None)]
1010
#[command(version)]
1111
struct Cli {
12+
/// Allow self-signed TLS certificates. Also honors SEREN_ALLOW_SELF_SIGNED_CERTS=1
13+
#[arg(long = "allow-self-signed-certs", global = true, default_value_t = false)]
14+
allow_self_signed_certs: bool,
1215
#[command(subcommand)]
1316
command: Commands,
1417
}
@@ -181,6 +184,24 @@ async fn main() -> anyhow::Result<()> {
181184

182185
let cli = Cli::parse();
183186

187+
// Honor allow-self-signed flag or env var by setting env consumed in postgres::connection
188+
let allow_self_signed_env = std::env::var("SEREN_ALLOW_SELF_SIGNED_CERTS")
189+
.ok()
190+
.map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
191+
.unwrap_or(false)
192+
// Backward compatibility with older env name
193+
|| std::env::var("SEREN_ALLOW_INVALID_CERTS")
194+
.ok()
195+
.map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
196+
.unwrap_or(false);
197+
198+
if cli.allow_self_signed_certs || allow_self_signed_env {
199+
// Set both names for compatibility
200+
std::env::set_var("SEREN_ALLOW_SELF_SIGNED_CERTS", "1");
201+
std::env::set_var("SEREN_ALLOW_INVALID_CERTS", "1");
202+
tracing::warn!("Allowing self-signed/invalid TLS certificates (insecure)");
203+
}
204+
184205
match cli.command {
185206
Commands::Validate {
186207
source,
@@ -402,7 +423,7 @@ async fn init_remote(
402423
drop_existing: bool,
403424
no_sync: bool,
404425
seren_api: String,
405-
_job_timeout: u64,
426+
job_timeout: u64,
406427
) -> anyhow::Result<()> {
407428
use database_replicator::migration;
408429
use database_replicator::postgres;
@@ -466,6 +487,8 @@ async fn init_remote(
466487
} else {
467488
Some(FilterSpec {
468489
include_databases,
490+
exclude_databases,
491+
include_tables,
469492
exclude_tables,
470493
})
471494
};
@@ -481,7 +504,12 @@ async fn init_remote(
481504
"estimated_size_bytes".to_string(),
482505
serde_json::Value::Number(serde_json::Number::from(estimated_size_bytes)),
483506
);
484-
// Note: "yes" and "job_timeout" are client-side only options, not sent to server
507+
// Optional timeout hint for remote orchestrator
508+
options.insert(
509+
"job_timeout_seconds".to_string(),
510+
serde_json::Value::Number(serde_json::Number::from(job_timeout as i64)),
511+
);
512+
// Note: "yes" is client-side only, not sent to server
485513

486514
let job_spec = JobSpec {
487515
version: "1.0".to_string(),

src/migration/dump.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pub async fn dump_globals(source_url: &str, output_path: &str) -> Result<()> {
6565
Duration::from_secs(1), // Start with 1 second delay
6666
"pg_dumpall (dump globals)",
6767
)
68+
.await
6869
.context(
6970
"pg_dumpall failed to dump global objects.\n\
7071
\n\
@@ -172,6 +173,7 @@ pub async fn dump_schema(
172173
Duration::from_secs(1), // Start with 1 second delay
173174
"pg_dump (dump schema)",
174175
)
176+
.await
175177
.with_context(|| {
176178
format!(
177179
"pg_dump failed to dump schema for database '{}'.\n\
@@ -299,6 +301,7 @@ pub async fn dump_data(
299301
Duration::from_secs(1), // Start with 1 second delay
300302
"pg_dump (dump data)",
301303
)
304+
.await
302305
.with_context(|| {
303306
format!(
304307
"pg_dump failed to dump data for database '{}'.\n\

src/migration/restore.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ pub async fn restore_globals(target_url: &str, input_path: &str) -> Result<()> {
6262
3, // Max 3 retries
6363
Duration::from_secs(1), // Start with 1 second delay
6464
"psql (restore globals)",
65-
);
65+
)
66+
.await;
6667

6768
// Handle result - don't fail on warnings for global objects
6869
match result {
@@ -136,6 +137,7 @@ pub async fn restore_schema(target_url: &str, input_path: &str) -> Result<()> {
136137
Duration::from_secs(1), // Start with 1 second delay
137138
"psql (restore schema)",
138139
)
140+
.await
139141
.context(
140142
"Schema restoration failed.\n\
141143
\n\
@@ -228,6 +230,7 @@ pub async fn restore_data(target_url: &str, input_path: &str) -> Result<()> {
228230
Duration::from_secs(1), // Start with 1 second delay
229231
"pg_restore (restore data)",
230232
)
233+
.await
231234
.context(
232235
"Data restoration failed.\n\
233236
\n\

src/mysql/mod.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,7 @@ pub fn validate_mysql_url(connection_string: &str) -> Result<String> {
4444
}
4545

4646
if !connection_string.starts_with("mysql://") {
47-
bail!(
48-
"Invalid MySQL connection string '{}'. \
49-
Must start with 'mysql://'",
50-
connection_string
51-
);
47+
bail!("Invalid MySQL connection string. Must start with 'mysql://'");
5248
}
5349

5450
tracing::debug!("Validated MySQL connection string");

src/postgres/connection.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,24 @@ pub async fn connect(connection_string: &str) -> Result<Client> {
130130
)?;
131131

132132
// Set up TLS connector for cloud connections
133-
// TEMPORARY: Accept invalid certs to debug TLS issues
134-
// TODO: Remove this once we identify the certificate validation issue
135-
let tls_connector = TlsConnector::builder()
136-
.danger_accept_invalid_certs(true)
133+
// By default, require valid certificates. Allow opt-in for self-signed/invalid certs via env.
134+
let allow_self_signed = std::env::var("SEREN_ALLOW_SELF_SIGNED_CERTS")
135+
.ok()
136+
.map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
137+
.unwrap_or(false)
138+
// Backward compatibility with older env name
139+
|| std::env::var("SEREN_ALLOW_INVALID_CERTS")
140+
.ok()
141+
.map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
142+
.unwrap_or(false);
143+
144+
let mut tls_builder = TlsConnector::builder();
145+
if allow_self_signed {
146+
tracing::warn!("Accepting self-signed/invalid TLS certificates");
147+
tls_builder.danger_accept_invalid_certs(true);
148+
}
149+
150+
let tls_connector = tls_builder
137151
.build()
138152
.context("Failed to build TLS connector")?;
139153
let tls = MakeTlsConnector::new(tls_connector);

src/remote/models.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ pub struct JobSpec {
1919
#[derive(Debug, Clone, Serialize, Deserialize)]
2020
pub struct FilterSpec {
2121
pub include_databases: Option<Vec<String>>,
22+
pub exclude_databases: Option<Vec<String>>,
23+
pub include_tables: Option<Vec<String>>,
2224
pub exclude_tables: Option<Vec<String>>,
2325
}
2426

src/replication/monitor.rs

Lines changed: 72 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -34,41 +34,47 @@ pub async fn get_replication_lag(
3434
client: &Client,
3535
subscription_name: Option<&str>,
3636
) -> Result<Vec<SourceReplicationStats>> {
37-
let query = if let Some(sub_name) = subscription_name {
38-
format!(
39-
"SELECT
40-
application_name,
41-
state,
42-
sent_lsn::text,
43-
write_lsn::text,
44-
flush_lsn::text,
45-
replay_lsn::text,
46-
EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
47-
EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
48-
EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
49-
FROM pg_stat_replication
50-
WHERE application_name = '{}'",
51-
sub_name
52-
)
53-
} else {
54-
"SELECT
55-
application_name,
56-
state,
57-
sent_lsn::text,
58-
write_lsn::text,
59-
flush_lsn::text,
60-
replay_lsn::text,
61-
EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
62-
EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
63-
EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
64-
FROM pg_stat_replication"
65-
.to_string()
66-
};
37+
if let Some(name) = subscription_name {
38+
crate::utils::validate_postgres_identifier(name).context("Invalid subscription name")?;
39+
}
6740

68-
let rows = client
69-
.query(&query, &[])
70-
.await
71-
.context("Failed to query replication statistics")?;
41+
let rows = if let Some(sub_name) = subscription_name {
42+
client
43+
.query(
44+
"SELECT
45+
application_name,
46+
state,
47+
sent_lsn::text,
48+
write_lsn::text,
49+
flush_lsn::text,
50+
replay_lsn::text,
51+
EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
52+
EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
53+
EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
54+
FROM pg_stat_replication
55+
WHERE application_name = $1",
56+
&[&sub_name],
57+
)
58+
.await
59+
} else {
60+
client
61+
.query(
62+
"SELECT
63+
application_name,
64+
state,
65+
sent_lsn::text,
66+
write_lsn::text,
67+
flush_lsn::text,
68+
replay_lsn::text,
69+
EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
70+
EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
71+
EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
72+
FROM pg_stat_replication",
73+
&[],
74+
)
75+
.await
76+
}
77+
.context("Failed to query replication statistics")?;
7278

7379
let mut stats = Vec::new();
7480
for row in rows {
@@ -94,33 +100,39 @@ pub async fn get_subscription_status(
94100
client: &Client,
95101
subscription_name: Option<&str>,
96102
) -> Result<Vec<SubscriptionStats>> {
97-
let query = if let Some(sub_name) = subscription_name {
98-
format!(
99-
"SELECT
100-
subname,
101-
pid,
102-
received_lsn::text,
103-
latest_end_lsn::text,
104-
srsubstate
105-
FROM pg_stat_subscription
106-
WHERE subname = '{}'",
107-
sub_name
108-
)
109-
} else {
110-
"SELECT
111-
subname,
112-
pid,
113-
received_lsn::text,
114-
latest_end_lsn::text,
115-
srsubstate
116-
FROM pg_stat_subscription"
117-
.to_string()
118-
};
103+
if let Some(name) = subscription_name {
104+
crate::utils::validate_postgres_identifier(name).context("Invalid subscription name")?;
105+
}
119106

120-
let rows = client
121-
.query(&query, &[])
122-
.await
123-
.context("Failed to query subscription statistics")?;
107+
let rows = if let Some(sub_name) = subscription_name {
108+
client
109+
.query(
110+
"SELECT
111+
subname,
112+
pid,
113+
received_lsn::text,
114+
latest_end_lsn::text,
115+
srsubstate
116+
FROM pg_stat_subscription
117+
WHERE subname = $1",
118+
&[&sub_name],
119+
)
120+
.await
121+
} else {
122+
client
123+
.query(
124+
"SELECT
125+
subname,
126+
pid,
127+
received_lsn::text,
128+
latest_end_lsn::text,
129+
srsubstate
130+
FROM pg_stat_subscription",
131+
&[],
132+
)
133+
.await
134+
}
135+
.context("Failed to query subscription statistics")?;
124136

125137
let mut stats = Vec::new();
126138
for row in rows {

src/replication/subscription.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,12 @@ pub async fn create_subscription(
5050
" To avoid storing passwords, configure .pgpass on the target PostgreSQL server"
5151
);
5252

53+
// Escape single quotes in connection string to avoid breaking SQL literal
54+
let escaped_conn = source_connection_string.replace('\'', "''");
55+
5356
let query = format!(
5457
"CREATE SUBSCRIPTION \"{}\" CONNECTION '{}' PUBLICATION \"{}\"",
55-
subscription_name, source_connection_string, publication_name
58+
subscription_name, escaped_conn, publication_name
5659
);
5760

5861
match client.execute(&query, &[]).await {

src/utils.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,11 @@ where
271271
/// 3, // Try up to 3 times
272272
/// Duration::from_secs(1), // Start with 1s delay
273273
/// "psql"
274-
/// )?;
274+
/// ).await?;
275275
/// # Ok(())
276276
/// # }
277277
/// ```
278-
pub fn retry_subprocess_with_backoff<F>(
278+
pub async fn retry_subprocess_with_backoff<F>(
279279
mut operation: F,
280280
max_retries: u32,
281281
initial_delay: Duration,
@@ -328,7 +328,7 @@ where
328328
last_error.as_ref().unwrap(),
329329
delay
330330
);
331-
std::thread::sleep(delay);
331+
tokio::time::sleep(delay).await;
332332
delay *= 2; // Exponential backoff
333333
}
334334
}
@@ -561,7 +561,18 @@ pub fn validate_source_target_different(source_url: &str, target_url: &str) -> R
561561
&& source_parts.user == target_parts.user
562562
{
563563
bail!(
564-
"Source and target URLs point to the same database!\\n\\\n \\n\\\n This would cause DATA LOSS - the target would overwrite the source.\\n\\\n \\n\\\n Source: {}@{}:{}/{}\\n\\\n Target: {}@{}:{}/{}\\n\\\n \\n\\\n Please ensure source and target are different databases.\\n\\\n Common causes:\\n\\\n - Copy-paste error in connection strings\\n\\\n - Wrong environment variables (e.g., SOURCE_URL == TARGET_URL)\\n\\\n - Typo in database name or host",
564+
"Source and target URLs point to the same database!\n\
565+
\n\
566+
This would cause DATA LOSS - the target would overwrite the source.\n\
567+
\n\
568+
Source: {}@{}:{}/{}\n\
569+
Target: {}@{}:{}/{}\n\
570+
\n\
571+
Please ensure source and target are different databases.\n\
572+
Common causes:\n\
573+
- Copy-paste error in connection strings\n\
574+
- Wrong environment variables (e.g., SOURCE_URL == TARGET_URL)\n\
575+
- Typo in database name or host",
565576
source_parts.user.as_deref().unwrap_or("(no user)"),
566577
source_parts.host,
567578
source_parts.port,

0 commit comments

Comments
 (0)