diff --git a/contrib/sync_db_pools/lib/Cargo.toml b/contrib/sync_db_pools/lib/Cargo.toml index b563e48dca..89c03392b8 100644 --- a/contrib/sync_db_pools/lib/Cargo.toml +++ b/contrib/sync_db_pools/lib/Cargo.toml @@ -16,6 +16,7 @@ diesel_postgres_pool = ["diesel/postgres", "diesel/r2d2"] diesel_mysql_pool = ["diesel/mysql", "diesel/r2d2"] sqlite_pool = ["rusqlite", "r2d2_sqlite"] postgres_pool = ["postgres", "r2d2_postgres"] +postgres_pool_tls = ["postgres_pool", "dep:postgres-native-tls", "dep:native-tls"] memcache_pool = ["memcache", "r2d2-memcache"] [dependencies] @@ -27,6 +28,8 @@ diesel = { version = "2.0.0", default-features = false, optional = true } postgres = { version = "0.19", optional = true } r2d2_postgres = { version = "0.18", optional = true } +postgres-native-tls = { version = "0.5", optional = true } +native-tls = { version = "0.2", optional = true } rusqlite = { version = "0.29.0", optional = true } r2d2_sqlite = { version = "0.22.0", optional = true } diff --git a/contrib/sync_db_pools/lib/src/config.rs b/contrib/sync_db_pools/lib/src/config.rs index a0938e27e6..b30d477a87 100644 --- a/contrib/sync_db_pools/lib/src/config.rs +++ b/contrib/sync_db_pools/lib/src/config.rs @@ -1,3 +1,5 @@ +use std::path::PathBuf; + use rocket::{Rocket, Build}; use rocket::figment::{self, Figment, providers::Serialized}; @@ -21,7 +23,8 @@ use serde::{Serialize, Deserialize}; /// Config { /// url: "postgres://root:root@localhost/my_database".into(), /// pool_size: 10, -/// timeout: 5 +/// timeout: 5, +/// tls: None, /// }; /// ``` /// @@ -39,6 +42,33 @@ pub struct Config { /// Defaults to `5`. // FIXME: Use `time`. pub timeout: u8, + /// TLS configuration. + pub tls: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct TlsConfig { + /// Allow TLS connections with invalid certificates. + /// + /// _Default:_ `false`. + pub accept_invalid_certs: bool, + /// Allow TLS connections with invalid hostnames. + /// + /// _Default:_ `false`. + pub accept_invalid_hostnames: bool, + /// Sets the name of a file containing SSL certificate authority (CA) certificate(s). + /// If the file exists, the server’s certificate will be verified to be signed by one of these authorities. + /// + /// _Default:_ `None`. + pub ssl_root_cert: Option, + /// Sets the name of a file containing SSL client certificate. + /// + /// _Default:_ `None`. + pub ssl_client_cert: Option, + /// Sets the name of a file containing SSL client key. + /// + /// _Default:_ `None`. + pub ssl_client_key: Option, } impl Config { @@ -107,10 +137,18 @@ impl Config { .map(|workers| workers * 4) .ok(); - let figment = Figment::from(rocket.figment()) + let mut figment = Figment::from(rocket.figment()) .focus(&db_key) .join(Serialized::default("timeout", 5)); + if figment.find_value("tls").is_ok() { + figment = figment.join(Serialized::default("tls.accept_invalid_certs", false)) + .join(Serialized::default("tls.accept_invalid_hostnames", false)) + .join(Serialized::default("tls.ssl_root_cert", None::)) + .join(Serialized::default("tls.ssl_client_cert", None::)) + .join(Serialized::default("tls.ssl_client_key", None::)); + } + match default_pool_size { Some(pool_size) => figment.join(Serialized::default("pool_size", pool_size)), None => figment diff --git a/contrib/sync_db_pools/lib/src/connection.rs b/contrib/sync_db_pools/lib/src/connection.rs index 73c6913b48..904ad0f095 100644 --- a/contrib/sync_db_pools/lib/src/connection.rs +++ b/contrib/sync_db_pools/lib/src/connection.rs @@ -88,6 +88,8 @@ impl ConnectionPool { Err(Error::Config(e)) => dberr!("config", db, "{}", e, rocket), Err(Error::Pool(e)) => dberr!("pool init", db, "{}", e, rocket), Err(Error::Custom(e)) => dberr!("pool manager", db, "{:?}", e, rocket), + Err(Error::Io(e)) => dberr!("io", db, "{:?}", e, rocket), + Err(Error::Tls(e)) => dberr!("tls", db, "{:?}", e, rocket), } }).await }) diff --git a/contrib/sync_db_pools/lib/src/error.rs b/contrib/sync_db_pools/lib/src/error.rs index fbf179e2a0..351f8a07ff 100644 --- a/contrib/sync_db_pools/lib/src/error.rs +++ b/contrib/sync_db_pools/lib/src/error.rs @@ -14,6 +14,10 @@ pub enum Error { Pool(r2d2::Error), /// An error occurred while extracting a `figment` configuration. Config(figment::Error), + /// An IO error occurred. + Io(std::io::Error), + /// A TLS error occurred. + Tls(Box), } impl From for Error { @@ -27,3 +31,9 @@ impl From for Error { Error::Pool(error) } } + +impl From for Error { + fn from(error: std::io::Error) -> Self { + Error::Io(error) + } +} diff --git a/contrib/sync_db_pools/lib/src/poolable.rs b/contrib/sync_db_pools/lib/src/poolable.rs index 0451de60bb..337862e531 100644 --- a/contrib/sync_db_pools/lib/src/poolable.rs +++ b/contrib/sync_db_pools/lib/src/poolable.rs @@ -184,16 +184,277 @@ impl Poolable for diesel::MysqlConnection { } } -// TODO: Add a feature to enable TLS in `postgres`; parse a suitable `config`. +#[cfg(feature = "postgres_pool")] +pub mod pg { + use std::pin::Pin; + use std::task::{Context, Poll}; + use std::io; + + #[derive(Clone)] + pub enum MaybeTlsConnector { + NoTls(postgres::tls::NoTls), + #[cfg(feature = "postgres_pool_tls")] + Tls(postgres_native_tls::MakeTlsConnector) + } + + impl postgres::tls::MakeTlsConnect for MaybeTlsConnector { + type Stream = MaybeTlsConnector_Stream; + type TlsConnect = MaybeTlsConnector_TlsConnect; + type Error = MaybeTlsConnector_Error; + + fn make_tls_connect(&mut self, domain: &str) -> Result { + match self { + MaybeTlsConnector::NoTls(connector) => { + > + ::make_tls_connect(connector, domain) + .map(Self::TlsConnect::NoTls) + .map_err(Self::Error::NoTls) + }, + #[cfg(feature = "postgres_pool_tls")] + MaybeTlsConnector::Tls(connector) => { + < + postgres_native_tls::MakeTlsConnector as + postgres::tls::MakeTlsConnect + >::make_tls_connect(connector, domain) + .map(Self::TlsConnect::Tls) + .map_err(Self::Error::Tls) + }, + } + } + } + + // --- Stream --- + + #[allow(non_camel_case_types)] + pub enum MaybeTlsConnector_Stream { + NoTls(postgres::tls::NoTlsStream), + #[cfg(feature = "postgres_pool_tls")] + Tls(postgres_native_tls::TlsStream) + } + + impl postgres::tls::TlsStream for MaybeTlsConnector_Stream { + fn channel_binding(&self) -> postgres::tls::ChannelBinding { + match self { + MaybeTlsConnector_Stream::NoTls(stream) => stream.channel_binding(), + #[cfg(feature = "postgres_pool_tls")] + MaybeTlsConnector_Stream::Tls(stream) => stream.channel_binding(), + } + } + } + + impl tokio::io::AsyncRead for MaybeTlsConnector_Stream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_> + ) -> Poll> { + match *self { + MaybeTlsConnector_Stream::NoTls(ref mut stream) => + Pin::new(stream).poll_read(cx, buf), + #[cfg(feature = "postgres_pool_tls")] + MaybeTlsConnector_Stream::Tls(ref mut stream) => + Pin::new(stream).poll_read(cx, buf), + } + } + } + + impl tokio::io::AsyncWrite for MaybeTlsConnector_Stream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8] + ) -> Poll> { + match *self { + MaybeTlsConnector_Stream::NoTls(ref mut stream) => + Pin::new(stream).poll_write(cx, buf), + #[cfg(feature = "postgres_pool_tls")] + MaybeTlsConnector_Stream::Tls(ref mut stream) => + Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + MaybeTlsConnector_Stream::NoTls(ref mut stream) => Pin::new(stream).poll_flush(cx), + #[cfg(feature = "postgres_pool_tls")] + MaybeTlsConnector_Stream::Tls(ref mut stream) => Pin::new(stream).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + MaybeTlsConnector_Stream::NoTls(ref mut stream) => + Pin::new(stream).poll_shutdown(cx), + #[cfg(feature = "postgres_pool_tls")] + MaybeTlsConnector_Stream::Tls(ref mut stream) => + Pin::new(stream).poll_shutdown(cx), + } + } + } + + // --- TlsConnect --- + + #[allow(non_camel_case_types)] + pub enum MaybeTlsConnector_TlsConnect { + NoTls(postgres::tls::NoTls), + #[cfg(feature = "postgres_pool_tls")] + Tls(postgres_native_tls::TlsConnector) + } + + impl postgres::tls::TlsConnect for MaybeTlsConnector_TlsConnect { + type Stream = MaybeTlsConnector_Stream; + type Error = MaybeTlsConnector_Error; + type Future = MaybeTlsConnector_Future; + + fn connect(self, socket: postgres::Socket) -> Self::Future { + match self { + MaybeTlsConnector_TlsConnect::NoTls(connector) => + Self::Future::NoTls(connector.connect(socket)), + #[cfg(feature = "postgres_pool_tls")] + MaybeTlsConnector_TlsConnect::Tls(connector) => + Self::Future::Tls(connector.connect(socket)), + } + } + } + + // --- Error --- + + #[allow(non_camel_case_types)] + #[derive(Debug)] + pub enum MaybeTlsConnector_Error { + NoTls(postgres::tls::NoTlsError), + #[cfg(feature = "postgres_pool_tls")] + Tls(native_tls::Error) + } + + impl std::fmt::Display for MaybeTlsConnector_Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MaybeTlsConnector_Error::NoTls(e) => e.fmt(f), + #[cfg(feature = "postgres_pool_tls")] + MaybeTlsConnector_Error::Tls(e) => e.fmt(f), + } + } + } + + impl std::error::Error for MaybeTlsConnector_Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + MaybeTlsConnector_Error::NoTls(e) => e.source(), + #[cfg(feature = "postgres_pool_tls")] + MaybeTlsConnector_Error::Tls(e) => e.source(), + } + } + } + + // --- Future --- + + #[allow(non_camel_case_types)] + pub enum MaybeTlsConnector_Future { + NoTls(postgres::tls::NoTlsFuture), + #[cfg(feature = "postgres_pool_tls")] + Tls(>::Future) + } + + impl std::future::Future for MaybeTlsConnector_Future { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match *self { + MaybeTlsConnector_Future::NoTls(ref mut future) => { + Pin::new(future) + .poll(cx) + .map(|v| v.map(MaybeTlsConnector_Stream::NoTls)) + .map_err(MaybeTlsConnector_Error::NoTls) + }, + #[cfg(feature = "postgres_pool_tls")] + MaybeTlsConnector_Future::Tls(ref mut future) => { + Pin::new(future) + .poll(cx) + .map(|v| v.map(MaybeTlsConnector_Stream::Tls)) + .map_err(MaybeTlsConnector_Error::Tls) + } + } + } + } +} + #[cfg(feature = "postgres_pool")] impl Poolable for postgres::Client { - type Manager = r2d2_postgres::PostgresConnectionManager; + type Manager = r2d2_postgres::PostgresConnectionManager; type Error = postgres::Error; fn pool(db_name: &str, rocket: &Rocket) -> PoolResult { let config = Config::from(db_name, rocket)?; let url = config.url.parse().map_err(Error::Custom)?; - let manager = r2d2_postgres::PostgresConnectionManager::new(url, postgres::tls::NoTls); + + let tls_connector = match config.tls { + // `tls_config` is unused when `postgres_pool_tls` is disabled. + #[allow(unused_variables)] + Some(ref tls_config) => { + + #[cfg(feature = "postgres_pool_tls")] + { + let mut connector_builder = native_tls::TlsConnector::builder(); + if let Some(ref cert) = tls_config.ssl_root_cert { + let cert_file_bytes = std::fs::read(cert)?; + let cert = native_tls::Certificate::from_pem(&cert_file_bytes) + .map_err(|e| Error::Tls(e.into()))?; + connector_builder.add_root_certificate(cert); + + // Client certs + match ( + tls_config.ssl_client_cert.as_ref(), + tls_config.ssl_client_key.as_ref(), + ) { + (Some(cert), Some(key)) => { + let cert_file_bytes = std::fs::read(cert)?; + let key_file_bytes = std::fs::read(key)?; + let cert = native_tls::Identity::from_pkcs8( + &cert_file_bytes, + &key_file_bytes + ).map_err(|e| Error::Tls(e.into()))?; + connector_builder.identity(cert); + }, + (Some(_), None) => { + return Err(Error::Tls( + "Client certificate provided without client key".into() + )) + }, + (None, Some(_)) => { + return Err(Error::Tls( + "Client key provided without client certificate".into() + )) + }, + (None, None) => {}, + } + } + + connector_builder + .danger_accept_invalid_certs(tls_config.accept_invalid_certs); + connector_builder + .danger_accept_invalid_hostnames(tls_config.accept_invalid_hostnames); + + pg::MaybeTlsConnector::Tls(postgres_native_tls::MakeTlsConnector::new( + connector_builder.build().map_err(|e| Error::Tls(e.into()))? + )) + } + + #[cfg(not(feature = "postgres_pool_tls"))] + { + // TODO: Should this be an error? + rocket::warn!("The `postgres_pool_tls` feature is disabled. \ + Postgres TLS configuration will be ignored."); + pg::MaybeTlsConnector::NoTls(postgres::tls::NoTls) + } + }, + None => { + pg::MaybeTlsConnector::NoTls(postgres::tls::NoTls) + } + }; + + let manager = r2d2_postgres::PostgresConnectionManager::new(url, tls_connector); let pool = r2d2::Pool::builder() .max_size(config.pool_size) .connection_timeout(Duration::from_secs(config.timeout as u64)) diff --git a/scripts/test.sh b/scripts/test.sh index 40525a1f05..e95f68f45e 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -90,6 +90,7 @@ function test_contrib() { diesel_sqlite_pool diesel_mysql_pool postgres_pool + postgres_pool_tls sqlite_pool memcache_pool )