Skip to content

Commit f497a91

Browse files
fix: consistent SQL quoting and thread-safe TLS config
- Add quote_literal() for SQL string escaping - Add quote_mysql_ident() for MySQL backtick quoting - Use quote_ident() consistently for PostgreSQL/SQLite identifiers - Replace unsafe std::env::set_var with thread-safe OnceLock - Fix blocking std::thread::sleep in async context - Update security test to verify credential leak fix
1 parent 7ba7c7c commit f497a91

10 files changed

Lines changed: 137 additions & 63 deletions

File tree

src/commands/init.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,10 @@ pub async fn init(
330330
.with_context(|| format!("Invalid database name: '{}'", db_info.name))?;
331331

332332
// Try to create database atomically (avoids TOCTOU vulnerability)
333-
let create_query = format!("CREATE DATABASE \"{}\"", db_info.name);
333+
let create_query = format!(
334+
"CREATE DATABASE {}",
335+
crate::utils::quote_ident(&db_info.name)
336+
);
334337
match target_client.execute(&create_query, &[]).await {
335338
Ok(_) => {
336339
tracing::info!(" Created database '{}'", db_info.name);
@@ -372,8 +375,10 @@ pub async fn init(
372375
drop_database_if_exists(&target_client, &db_info.name).await?;
373376

374377
// Recreate the database
375-
let create_query =
376-
format!("CREATE DATABASE \"{}\"", db_info.name);
378+
let create_query = format!(
379+
"CREATE DATABASE {}",
380+
crate::utils::quote_ident(&db_info.name)
381+
);
377382
target_client
378383
.execute(&create_query, &[])
379384
.await
@@ -666,7 +671,10 @@ async fn drop_database_if_exists(target_conn: &Client, db_name: &str) -> Result<
666671
target_conn.execute(terminate_query, &[&db_name]).await?;
667672

668673
// Drop the database
669-
let drop_query = format!("DROP DATABASE IF EXISTS \"{}\"", db_name);
674+
let drop_query = format!(
675+
"DROP DATABASE IF EXISTS {}",
676+
crate::utils::quote_ident(db_name)
677+
);
670678
target_conn
671679
.execute(&drop_query, &[])
672680
.await

src/main.rs

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@ use database_replicator::commands;
99
#[command(about = "Universal database-to-PostgreSQL replication CLI", long_about = None)]
1010
#[command(version)]
1111
struct Cli {
12-
/// Allow self-signed TLS certificates. Also honors SEREN_ALLOW_SELF_SIGNED_CERTS=1
13-
#[arg(long = "allow-self-signed-certs", global = true, default_value_t = false)]
12+
/// Allow self-signed TLS certificates (insecure - use only for testing)
13+
#[arg(
14+
long = "allow-self-signed-certs",
15+
global = true,
16+
default_value_t = false
17+
)]
1418
allow_self_signed_certs: bool,
1519
#[command(subcommand)]
1620
command: Commands,
@@ -184,23 +188,8 @@ async fn main() -> anyhow::Result<()> {
184188

185189
let cli = Cli::parse();
186190

187-
// Honor allow-self-signed flag or env var by setting env consumed in postgres::connection
188-
let allow_self_signed_env = std::env::var("SEREN_ALLOW_SELF_SIGNED_CERTS")
189-
.ok()
190-
.map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
191-
.unwrap_or(false)
192-
// Backward compatibility with older env name
193-
|| std::env::var("SEREN_ALLOW_INVALID_CERTS")
194-
.ok()
195-
.map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
196-
.unwrap_or(false);
197-
198-
if cli.allow_self_signed_certs || allow_self_signed_env {
199-
// Set both names for compatibility
200-
std::env::set_var("SEREN_ALLOW_SELF_SIGNED_CERTS", "1");
201-
std::env::set_var("SEREN_ALLOW_INVALID_CERTS", "1");
202-
tracing::warn!("Allowing self-signed/invalid TLS certificates (insecure)");
203-
}
191+
// Initialize TLS policy using thread-safe OnceLock
192+
database_replicator::postgres::connection::init_tls_policy(cli.allow_self_signed_certs);
204193

205194
match cli.command {
206195
Commands::Validate {

src/mysql/reader.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,11 @@ pub async fn get_table_row_count(
8686
tracing::debug!("Getting row count for table '{}.{}'", db_name, table_name);
8787

8888
// Use backticks for identifiers to allow reserved words
89-
let query = format!("SELECT COUNT(*) FROM `{}`.`{}`", db_name, table_name);
89+
let query = format!(
90+
"SELECT COUNT(*) FROM {}.{}",
91+
crate::utils::quote_mysql_ident(db_name),
92+
crate::utils::quote_mysql_ident(table_name)
93+
);
9094

9195
let count: Option<u64> = conn
9296
.query_first(&query)
@@ -137,7 +141,11 @@ pub async fn read_table_data(conn: &mut Conn, db_name: &str, table_name: &str) -
137141
tracing::info!("Reading all rows from table '{}.{}'", db_name, table_name);
138142

139143
// Use backticks for identifiers
140-
let query = format!("SELECT * FROM `{}`.`{}`", db_name, table_name);
144+
let query = format!(
145+
"SELECT * FROM {}.{}",
146+
crate::utils::quote_mysql_ident(db_name),
147+
crate::utils::quote_mysql_ident(table_name)
148+
);
141149

142150
let rows: Vec<Row> = conn
143151
.query(&query)

src/postgres/connection.rs

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,28 @@ use crate::utils;
55
use anyhow::{Context, Result};
66
use native_tls::TlsConnector;
77
use postgres_native_tls::MakeTlsConnector;
8+
use std::sync::OnceLock;
89
use std::time::Duration;
910
use tokio_postgres::Client;
1011

12+
/// Thread-safe storage for TLS configuration set at startup
13+
static ALLOW_SELF_SIGNED_CERTS: OnceLock<bool> = OnceLock::new();
14+
15+
/// Initialize the TLS certificate policy (call once at startup)
16+
///
17+
/// This must be called before any database connections are made.
18+
/// It is thread-safe and will only set the value once.
19+
///
20+
/// # Arguments
21+
///
22+
/// * `allow` - If true, accept self-signed/invalid TLS certificates (insecure)
23+
pub fn init_tls_policy(allow: bool) {
24+
let _ = ALLOW_SELF_SIGNED_CERTS.set(allow);
25+
if allow {
26+
tracing::warn!("TLS policy: Allowing self-signed/invalid certificates (insecure)");
27+
}
28+
}
29+
1130
/// Add TCP keepalive parameters to a PostgreSQL connection string
1231
///
1332
/// Automatically adds keepalive parameters to prevent idle connection timeouts
@@ -130,20 +149,11 @@ pub async fn connect(connection_string: &str) -> Result<Client> {
130149
)?;
131150

132151
// Set up TLS connector for cloud connections
133-
// By default, require valid certificates. Allow opt-in for self-signed/invalid certs via env.
134-
let allow_self_signed = std::env::var("SEREN_ALLOW_SELF_SIGNED_CERTS")
135-
.ok()
136-
.map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
137-
.unwrap_or(false)
138-
// Backward compatibility with older env name
139-
|| std::env::var("SEREN_ALLOW_INVALID_CERTS")
140-
.ok()
141-
.map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
142-
.unwrap_or(false);
152+
// By default, require valid certificates. Allow opt-in via init_tls_policy() called at startup.
153+
let allow_self_signed = ALLOW_SELF_SIGNED_CERTS.get().copied().unwrap_or(false);
143154

144155
let mut tls_builder = TlsConnector::builder();
145156
if allow_self_signed {
146-
tracing::warn!("Accepting self-signed/invalid TLS certificates");
147157
tls_builder.danger_accept_invalid_certs(true);
148158
}
149159

src/replication/publication.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ pub async fn create_publication(
3939
tracing::info!("Creating publication '{}'...", publication_name);
4040

4141
if filter.is_empty() {
42-
let query = format!("CREATE PUBLICATION \"{}\" FOR ALL TABLES", publication_name);
42+
let query = format!(
43+
"CREATE PUBLICATION {} FOR ALL TABLES",
44+
crate::utils::quote_ident(publication_name)
45+
);
4346
return execute_publication_query(client, publication_name, &query).await;
4447
}
4548

@@ -121,8 +124,8 @@ pub async fn create_publication(
121124
);
122125

123126
let query = format!(
124-
"CREATE PUBLICATION \"{}\" FOR TABLE {}",
125-
publication_name,
127+
"CREATE PUBLICATION {} FOR TABLE {}",
128+
crate::utils::quote_ident(publication_name),
126129
clauses.join(", ")
127130
);
128131

@@ -218,7 +221,10 @@ pub async fn drop_publication(client: &Client, publication_name: &str) -> Result
218221

219222
tracing::info!("Dropping publication '{}'...", publication_name);
220223

221-
let query = format!("DROP PUBLICATION IF EXISTS \"{}\"", publication_name);
224+
let query = format!(
225+
"DROP PUBLICATION IF EXISTS {}",
226+
crate::utils::quote_ident(publication_name)
227+
);
222228

223229
client
224230
.execute(&query, &[])

src/replication/subscription.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,11 @@ pub async fn create_subscription(
5050
" To avoid storing passwords, configure .pgpass on the target PostgreSQL server"
5151
);
5252

53-
// Escape single quotes in connection string to avoid breaking SQL literal
54-
let escaped_conn = source_connection_string.replace('\'', "''");
55-
5653
let query = format!(
57-
"CREATE SUBSCRIPTION \"{}\" CONNECTION '{}' PUBLICATION \"{}\"",
58-
subscription_name, escaped_conn, publication_name
54+
"CREATE SUBSCRIPTION {} CONNECTION {} PUBLICATION {}",
55+
crate::utils::quote_ident(subscription_name),
56+
crate::utils::quote_literal(source_connection_string),
57+
crate::utils::quote_ident(publication_name)
5958
);
6059

6160
match client.execute(&query, &[]).await {
@@ -158,7 +157,10 @@ pub async fn drop_subscription(client: &Client, subscription_name: &str) -> Resu
158157

159158
tracing::info!("Dropping subscription '{}'...", subscription_name);
160159

161-
let query = format!("DROP SUBSCRIPTION IF EXISTS \"{}\"", subscription_name);
160+
let query = format!(
161+
"DROP SUBSCRIPTION IF EXISTS {}",
162+
crate::utils::quote_ident(subscription_name)
163+
);
162164

163165
client.execute(&query, &[]).await.context(format!(
164166
"Failed to drop subscription '{}'",

src/sqlite/reader.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ pub fn get_table_row_count(conn: &Connection, table: &str) -> Result<usize> {
9595
tracing::debug!("Getting row count for table '{}'", table);
9696

9797
// Note: table name is validated above, so it's safe to use in SQL
98-
let query = format!("SELECT COUNT(*) FROM \"{}\"", table);
98+
let query = format!("SELECT COUNT(*) FROM {}", crate::utils::quote_ident(table));
9999

100100
let count: i64 = conn
101101
.query_row(&query, [], |row| row.get(0))
@@ -151,7 +151,7 @@ pub fn read_table_data(
151151
tracing::info!("Reading all data from table '{}'", table);
152152

153153
// Note: table name is validated above
154-
let query = format!("SELECT * FROM \"{}\"", table);
154+
let query = format!("SELECT * FROM {}", crate::utils::quote_ident(table));
155155

156156
let mut stmt = conn
157157
.prepare(&query)

src/utils.rs

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ where
261261
/// # use std::time::Duration;
262262
/// # use std::process::Command;
263263
/// # use database_replicator::utils::retry_subprocess_with_backoff;
264-
/// # fn example() -> Result<()> {
264+
/// # async fn example() -> Result<()> {
265265
/// retry_subprocess_with_backoff(
266266
/// || {
267267
/// let mut cmd = Command::new("psql");
@@ -311,7 +311,7 @@ where
311311
max_retries + 1,
312312
delay
313313
);
314-
std::thread::sleep(delay);
314+
tokio::time::sleep(delay).await;
315315
delay *= 2; // Exponential backoff
316316
}
317317
}
@@ -490,6 +490,57 @@ pub fn quote_ident(identifier: &str) -> String {
490490
quoted
491491
}
492492

493+
/// Quote a SQL string literal (for use in SQL statements)
494+
///
495+
/// Escapes single quotes by doubling them and wraps the string in single quotes.
496+
/// Use this for string values in SQL, not for identifiers.
497+
///
498+
/// # Examples
499+
///
500+
/// ```
501+
/// use database_replicator::utils::quote_literal;
502+
/// assert_eq!(quote_literal("hello"), "'hello'");
503+
/// assert_eq!(quote_literal("it's"), "'it''s'");
504+
/// assert_eq!(quote_literal(""), "''");
505+
/// ```
506+
pub fn quote_literal(value: &str) -> String {
507+
let mut quoted = String::with_capacity(value.len() + 2);
508+
quoted.push('\'');
509+
for ch in value.chars() {
510+
if ch == '\'' {
511+
quoted.push('\'');
512+
}
513+
quoted.push(ch);
514+
}
515+
quoted.push('\'');
516+
quoted
517+
}
518+
519+
/// Quote a MySQL identifier (database, table, column)
520+
///
521+
/// MySQL uses backticks for identifier quoting. Escapes embedded backticks
522+
/// by doubling them.
523+
///
524+
/// # Examples
525+
///
526+
/// ```
527+
/// use database_replicator::utils::quote_mysql_ident;
528+
/// assert_eq!(quote_mysql_ident("users"), "`users`");
529+
/// assert_eq!(quote_mysql_ident("user`name"), "`user``name`");
530+
/// ```
531+
pub fn quote_mysql_ident(identifier: &str) -> String {
532+
let mut quoted = String::with_capacity(identifier.len() + 2);
533+
quoted.push('`');
534+
for ch in identifier.chars() {
535+
if ch == '`' {
536+
quoted.push('`');
537+
}
538+
quoted.push(ch);
539+
}
540+
quoted.push('`');
541+
quoted
542+
}
543+
493544
/// Validate that source and target URLs are different to prevent accidental data loss
494545
///
495546
/// Compares two PostgreSQL connection URLs to ensure they point to different databases.

tests/integration_remote_test.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,8 @@ async fn test_remote_job_submission_with_filters() {
266266
// Create a job spec with database filters
267267
let filter = database_replicator::remote::FilterSpec {
268268
include_databases: Some(vec!["postgres".to_string()]),
269+
exclude_databases: None,
270+
include_tables: None,
269271
exclude_tables: None,
270272
};
271273

tests/security_test.rs

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -787,33 +787,31 @@ fn test_mysql_url_with_special_chars_in_password() {
787787
fn test_mysql_error_messages_dont_leak_credentials() {
788788
use database_replicator::mysql;
789789

790-
// SECURITY NOTE: Current implementation includes full URL in error messages
791-
// This test documents the current behavior - ideally this should be fixed
792-
// to sanitize URLs before including in error messages
793-
790+
// SECURITY: Error messages should NOT leak passwords or full URLs
794791
let url_with_password = "not-mysql://admin:secretpass@host:3306/db";
795792

796793
let result = mysql::validate_mysql_url(url_with_password);
797794
assert!(result.is_err(), "Invalid URL should be rejected");
798795

799796
let error_msg = result.unwrap_err().to_string();
800797

801-
// KNOWN ISSUE: Error message currently contains the full URL including password
802-
// This test verifies current behavior, but this should be improved
798+
// Verify password is NOT leaked in error message
803799
assert!(
804-
error_msg.contains("secretpass") || error_msg.contains("not-mysql://"),
805-
"Error message currently includes full URL (known issue)"
800+
!error_msg.contains("secretpass"),
801+
"Error message should not contain password: {error_msg}"
802+
);
803+
804+
// Verify full URL is NOT leaked in error message
805+
assert!(
806+
!error_msg.contains("not-mysql://"),
807+
"Error message should not contain full malformed URL: {error_msg}"
806808
);
807809

808810
// Verify it does explain the validation failure
809811
assert!(
810812
error_msg.contains("mysql://") || error_msg.contains("Invalid"),
811-
"Error should explain validation requirement"
813+
"Error should explain validation requirement: {error_msg}"
812814
);
813-
814-
// TODO: Enhance validate_mysql_url to sanitize URLs in error messages
815-
// Expected: "Invalid MySQL connection string. Must start with 'mysql://'"
816-
// (without exposing the actual malformed URL)
817815
}
818816

819817
// ----------------------------------------------------------------------------

0 commit comments

Comments
 (0)