Skip to content

Commit aaefa5b

Browse files
fix: harden SQL quoting, TLS handling, and async operations (#1)
* fix: harden remote filters and async retries
* fix: consistent SQL quoting and thread-safe TLS config
  - Add quote_literal() for SQL string escaping
  - Add quote_mysql_ident() for MySQL backtick quoting
  - Use quote_ident() consistently for PostgreSQL/SQLite identifiers
  - Replace unsafe std::env::set_var with thread-safe OnceLock
  - Fix blocking std::thread::sleep in async context
  - Update security test to verify credential leak fix
1 parent 753a26a commit aaefa5b

15 files changed

Lines changed: 252 additions & 106 deletions

File tree

src/commands/init.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,10 @@ pub async fn init(
330330
.with_context(|| format!("Invalid database name: '{}'", db_info.name))?;
331331

332332
// Try to create database atomically (avoids TOCTOU vulnerability)
333-
let create_query = format!("CREATE DATABASE \"{}\"", db_info.name);
333+
let create_query = format!(
334+
"CREATE DATABASE {}",
335+
crate::utils::quote_ident(&db_info.name)
336+
);
334337
match target_client.execute(&create_query, &[]).await {
335338
Ok(_) => {
336339
tracing::info!(" Created database '{}'", db_info.name);
@@ -372,8 +375,10 @@ pub async fn init(
372375
drop_database_if_exists(&target_client, &db_info.name).await?;
373376

374377
// Recreate the database
375-
let create_query =
376-
format!("CREATE DATABASE \"{}\"", db_info.name);
378+
let create_query = format!(
379+
"CREATE DATABASE {}",
380+
crate::utils::quote_ident(&db_info.name)
381+
);
377382
target_client
378383
.execute(&create_query, &[])
379384
.await
@@ -666,7 +671,10 @@ async fn drop_database_if_exists(target_conn: &Client, db_name: &str) -> Result<
666671
target_conn.execute(terminate_query, &[&db_name]).await?;
667672

668673
// Drop the database
669-
let drop_query = format!("DROP DATABASE IF EXISTS \"{}\"", db_name);
674+
let drop_query = format!(
675+
"DROP DATABASE IF EXISTS {}",
676+
crate::utils::quote_ident(db_name)
677+
);
670678
target_conn
671679
.execute(&drop_query, &[])
672680
.await

src/main.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@ use database_replicator::commands;
99
#[command(about = "Universal database-to-PostgreSQL replication CLI", long_about = None)]
1010
#[command(version)]
1111
struct Cli {
12+
/// Allow self-signed TLS certificates (insecure - use only for testing)
13+
#[arg(
14+
long = "allow-self-signed-certs",
15+
global = true,
16+
default_value_t = false
17+
)]
18+
allow_self_signed_certs: bool,
1219
#[command(subcommand)]
1320
command: Commands,
1421
}
@@ -181,6 +188,9 @@ async fn main() -> anyhow::Result<()> {
181188

182189
let cli = Cli::parse();
183190

191+
// Initialize TLS policy using thread-safe OnceLock
192+
database_replicator::postgres::connection::init_tls_policy(cli.allow_self_signed_certs);
193+
184194
match cli.command {
185195
Commands::Validate {
186196
source,
@@ -402,7 +412,7 @@ async fn init_remote(
402412
drop_existing: bool,
403413
no_sync: bool,
404414
seren_api: String,
405-
_job_timeout: u64,
415+
job_timeout: u64,
406416
) -> anyhow::Result<()> {
407417
use database_replicator::migration;
408418
use database_replicator::postgres;
@@ -466,6 +476,8 @@ async fn init_remote(
466476
} else {
467477
Some(FilterSpec {
468478
include_databases,
479+
exclude_databases,
480+
include_tables,
469481
exclude_tables,
470482
})
471483
};
@@ -481,7 +493,12 @@ async fn init_remote(
481493
"estimated_size_bytes".to_string(),
482494
serde_json::Value::Number(serde_json::Number::from(estimated_size_bytes)),
483495
);
484-
// Note: "yes" and "job_timeout" are client-side only options, not sent to server
496+
// Optional timeout hint for remote orchestrator
497+
options.insert(
498+
"job_timeout_seconds".to_string(),
499+
serde_json::Value::Number(serde_json::Number::from(job_timeout as i64)),
500+
);
501+
// Note: "yes" is client-side only, not sent to server
485502

486503
let job_spec = JobSpec {
487504
version: "1.0".to_string(),

src/migration/dump.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pub async fn dump_globals(source_url: &str, output_path: &str) -> Result<()> {
6565
Duration::from_secs(1), // Start with 1 second delay
6666
"pg_dumpall (dump globals)",
6767
)
68+
.await
6869
.context(
6970
"pg_dumpall failed to dump global objects.\n\
7071
\n\
@@ -172,6 +173,7 @@ pub async fn dump_schema(
172173
Duration::from_secs(1), // Start with 1 second delay
173174
"pg_dump (dump schema)",
174175
)
176+
.await
175177
.with_context(|| {
176178
format!(
177179
"pg_dump failed to dump schema for database '{}'.\n\
@@ -299,6 +301,7 @@ pub async fn dump_data(
299301
Duration::from_secs(1), // Start with 1 second delay
300302
"pg_dump (dump data)",
301303
)
304+
.await
302305
.with_context(|| {
303306
format!(
304307
"pg_dump failed to dump data for database '{}'.\n\

src/migration/restore.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ pub async fn restore_globals(target_url: &str, input_path: &str) -> Result<()> {
6262
3, // Max 3 retries
6363
Duration::from_secs(1), // Start with 1 second delay
6464
"psql (restore globals)",
65-
);
65+
)
66+
.await;
6667

6768
// Handle result - don't fail on warnings for global objects
6869
match result {
@@ -136,6 +137,7 @@ pub async fn restore_schema(target_url: &str, input_path: &str) -> Result<()> {
136137
Duration::from_secs(1), // Start with 1 second delay
137138
"psql (restore schema)",
138139
)
140+
.await
139141
.context(
140142
"Schema restoration failed.\n\
141143
\n\
@@ -228,6 +230,7 @@ pub async fn restore_data(target_url: &str, input_path: &str) -> Result<()> {
228230
Duration::from_secs(1), // Start with 1 second delay
229231
"pg_restore (restore data)",
230232
)
233+
.await
231234
.context(
232235
"Data restoration failed.\n\
233236
\n\

src/mysql/mod.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,7 @@ pub fn validate_mysql_url(connection_string: &str) -> Result<String> {
4444
}
4545

4646
if !connection_string.starts_with("mysql://") {
47-
bail!(
48-
"Invalid MySQL connection string '{}'. \
49-
Must start with 'mysql://'",
50-
connection_string
51-
);
47+
bail!("Invalid MySQL connection string. Must start with 'mysql://'");
5248
}
5349

5450
tracing::debug!("Validated MySQL connection string");

src/mysql/reader.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,11 @@ pub async fn get_table_row_count(
8686
tracing::debug!("Getting row count for table '{}.{}'", db_name, table_name);
8787

8888
// Use backticks for identifiers to allow reserved words
89-
let query = format!("SELECT COUNT(*) FROM `{}`.`{}`", db_name, table_name);
89+
let query = format!(
90+
"SELECT COUNT(*) FROM {}.{}",
91+
crate::utils::quote_mysql_ident(db_name),
92+
crate::utils::quote_mysql_ident(table_name)
93+
);
9094

9195
let count: Option<u64> = conn
9296
.query_first(&query)
@@ -137,7 +141,11 @@ pub async fn read_table_data(conn: &mut Conn, db_name: &str, table_name: &str) -
137141
tracing::info!("Reading all rows from table '{}.{}'", db_name, table_name);
138142

139143
// Use backticks for identifiers
140-
let query = format!("SELECT * FROM `{}`.`{}`", db_name, table_name);
144+
let query = format!(
145+
"SELECT * FROM {}.{}",
146+
crate::utils::quote_mysql_ident(db_name),
147+
crate::utils::quote_mysql_ident(table_name)
148+
);
141149

142150
let rows: Vec<Row> = conn
143151
.query(&query)

src/postgres/connection.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,28 @@ use crate::utils;
55
use anyhow::{Context, Result};
66
use native_tls::TlsConnector;
77
use postgres_native_tls::MakeTlsConnector;
8+
use std::sync::OnceLock;
89
use std::time::Duration;
910
use tokio_postgres::Client;
1011

12+
/// Thread-safe storage for TLS configuration set at startup
13+
static ALLOW_SELF_SIGNED_CERTS: OnceLock<bool> = OnceLock::new();
14+
15+
/// Initialize the TLS certificate policy (call once at startup)
16+
///
17+
/// Call this before any database connections are made; a connection opened while the policy is unset validates certificates.
18+
/// It is thread-safe; only the first call stores the value, and later calls leave it unchanged.
19+
///
20+
/// # Arguments
21+
///
22+
/// * `allow` - If true, accept self-signed/invalid TLS certificates (insecure)
23+
pub fn init_tls_policy(allow: bool) {
24+
let _ = ALLOW_SELF_SIGNED_CERTS.set(allow);
25+
if allow {
26+
tracing::warn!("TLS policy: Allowing self-signed/invalid certificates (insecure)");
27+
}
28+
}
29+
1130
/// Add TCP keepalive parameters to a PostgreSQL connection string
1231
///
1332
/// Automatically adds keepalive parameters to prevent idle connection timeouts
@@ -130,10 +149,15 @@ pub async fn connect(connection_string: &str) -> Result<Client> {
130149
)?;
131150

132151
// Set up TLS connector for cloud connections
133-
// TEMPORARY: Accept invalid certs to debug TLS issues
134-
// TODO: Remove this once we identify the certificate validation issue
135-
let tls_connector = TlsConnector::builder()
136-
.danger_accept_invalid_certs(true)
152+
// By default, require valid certificates. Allow opt-in via init_tls_policy() called at startup.
153+
let allow_self_signed = ALLOW_SELF_SIGNED_CERTS.get().copied().unwrap_or(false);
154+
155+
let mut tls_builder = TlsConnector::builder();
156+
if allow_self_signed {
157+
tls_builder.danger_accept_invalid_certs(true);
158+
}
159+
160+
let tls_connector = tls_builder
137161
.build()
138162
.context("Failed to build TLS connector")?;
139163
let tls = MakeTlsConnector::new(tls_connector);

src/remote/models.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ pub struct JobSpec {
1919
#[derive(Debug, Clone, Serialize, Deserialize)]
2020
pub struct FilterSpec {
2121
pub include_databases: Option<Vec<String>>,
22+
pub exclude_databases: Option<Vec<String>>,
23+
pub include_tables: Option<Vec<String>>,
2224
pub exclude_tables: Option<Vec<String>>,
2325
}
2426

src/replication/monitor.rs

Lines changed: 72 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -34,41 +34,47 @@ pub async fn get_replication_lag(
3434
client: &Client,
3535
subscription_name: Option<&str>,
3636
) -> Result<Vec<SourceReplicationStats>> {
37-
let query = if let Some(sub_name) = subscription_name {
38-
format!(
39-
"SELECT
40-
application_name,
41-
state,
42-
sent_lsn::text,
43-
write_lsn::text,
44-
flush_lsn::text,
45-
replay_lsn::text,
46-
EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
47-
EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
48-
EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
49-
FROM pg_stat_replication
50-
WHERE application_name = '{}'",
51-
sub_name
52-
)
53-
} else {
54-
"SELECT
55-
application_name,
56-
state,
57-
sent_lsn::text,
58-
write_lsn::text,
59-
flush_lsn::text,
60-
replay_lsn::text,
61-
EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
62-
EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
63-
EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
64-
FROM pg_stat_replication"
65-
.to_string()
66-
};
37+
if let Some(name) = subscription_name {
38+
crate::utils::validate_postgres_identifier(name).context("Invalid subscription name")?;
39+
}
6740

68-
let rows = client
69-
.query(&query, &[])
70-
.await
71-
.context("Failed to query replication statistics")?;
41+
let rows = if let Some(sub_name) = subscription_name {
42+
client
43+
.query(
44+
"SELECT
45+
application_name,
46+
state,
47+
sent_lsn::text,
48+
write_lsn::text,
49+
flush_lsn::text,
50+
replay_lsn::text,
51+
EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
52+
EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
53+
EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
54+
FROM pg_stat_replication
55+
WHERE application_name = $1",
56+
&[&sub_name],
57+
)
58+
.await
59+
} else {
60+
client
61+
.query(
62+
"SELECT
63+
application_name,
64+
state,
65+
sent_lsn::text,
66+
write_lsn::text,
67+
flush_lsn::text,
68+
replay_lsn::text,
69+
EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
70+
EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
71+
EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
72+
FROM pg_stat_replication",
73+
&[],
74+
)
75+
.await
76+
}
77+
.context("Failed to query replication statistics")?;
7278

7379
let mut stats = Vec::new();
7480
for row in rows {
@@ -94,33 +100,39 @@ pub async fn get_subscription_status(
94100
client: &Client,
95101
subscription_name: Option<&str>,
96102
) -> Result<Vec<SubscriptionStats>> {
97-
let query = if let Some(sub_name) = subscription_name {
98-
format!(
99-
"SELECT
100-
subname,
101-
pid,
102-
received_lsn::text,
103-
latest_end_lsn::text,
104-
srsubstate
105-
FROM pg_stat_subscription
106-
WHERE subname = '{}'",
107-
sub_name
108-
)
109-
} else {
110-
"SELECT
111-
subname,
112-
pid,
113-
received_lsn::text,
114-
latest_end_lsn::text,
115-
srsubstate
116-
FROM pg_stat_subscription"
117-
.to_string()
118-
};
103+
if let Some(name) = subscription_name {
104+
crate::utils::validate_postgres_identifier(name).context("Invalid subscription name")?;
105+
}
119106

120-
let rows = client
121-
.query(&query, &[])
122-
.await
123-
.context("Failed to query subscription statistics")?;
107+
let rows = if let Some(sub_name) = subscription_name {
108+
client
109+
.query(
110+
"SELECT
111+
subname,
112+
pid,
113+
received_lsn::text,
114+
latest_end_lsn::text,
115+
srsubstate
116+
FROM pg_stat_subscription
117+
WHERE subname = $1",
118+
&[&sub_name],
119+
)
120+
.await
121+
} else {
122+
client
123+
.query(
124+
"SELECT
125+
subname,
126+
pid,
127+
received_lsn::text,
128+
latest_end_lsn::text,
129+
srsubstate
130+
FROM pg_stat_subscription",
131+
&[],
132+
)
133+
.await
134+
}
135+
.context("Failed to query subscription statistics")?;
124136

125137
let mut stats = Vec::new();
126138
for row in rows {

0 commit comments

Comments
 (0)