@@ -16,7 +16,9 @@ use lightning::util::persist::{
1616 KVStore , KVStoreSync , PageToken , PaginatedKVStore , PaginatedKVStoreSync , PaginatedListResponse ,
1717} ;
1818use lightning_types:: string:: PrintableString ;
19- use tokio_postgres:: { connect, Client , Config , Error as PgError , NoTls } ;
19+ use native_tls:: TlsConnector ;
20+ use postgres_native_tls:: MakeTlsConnector ;
21+ use tokio_postgres:: { Client , Config , Error as PgError , NoTls } ;
2022
2123use crate :: io:: utils:: check_namespace_key_validity;
2224
@@ -65,7 +67,17 @@ impl PostgresStore {
6567 /// if it doesn't already exist.
6668 ///
6769 /// The given `kv_table_name` will be used or default to [`DEFAULT_KV_TABLE_NAME`].
68- pub fn new ( connection_string : String , kv_table_name : Option < String > ) -> io:: Result < Self > {
70+ ///
71+ /// If `tls_config` is `Some`, TLS will be used for database connections. A custom CA
72+ /// certificate can be provided via [`PostgresTlsConfig::certificate_pem`], otherwise the
73+ /// system's default root certificates are used. If `tls_config` is `None`, connections
74+ /// will be unencrypted.
75+ pub fn new (
76+ connection_string : String , kv_table_name : Option < String > ,
77+ tls_config : Option < PostgresTlsConfig > ,
78+ ) -> io:: Result < Self > {
79+ let tls = Self :: build_tls_connector ( tls_config) ?;
80+
6981 let internal_runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
7082 . enable_all ( )
7183 . thread_name_fn ( || {
@@ -79,15 +91,41 @@ impl PostgresStore {
7991 . unwrap ( ) ;
8092
8193 let inner = tokio:: task:: block_in_place ( || {
82- internal_runtime
83- . block_on ( async { PostgresStoreInner :: new ( connection_string, kv_table_name) . await } )
94+ internal_runtime. block_on ( async {
95+ PostgresStoreInner :: new ( connection_string, kv_table_name, tls) . await
96+ } )
8497 } ) ?;
8598
8699 let inner = Arc :: new ( inner) ;
87100 let next_write_version = AtomicU64 :: new ( 1 ) ;
88101 Ok ( Self { inner, next_write_version, internal_runtime : Some ( internal_runtime) } )
89102 }
90103
104+ fn build_tls_connector ( tls_config : Option < PostgresTlsConfig > ) -> io:: Result < PgTlsConnector > {
105+ match tls_config {
106+ Some ( config) => {
107+ let mut builder = TlsConnector :: builder ( ) ;
108+ if let Some ( pem) = config. certificate_pem {
109+ let crt = native_tls:: Certificate :: from_pem ( pem. as_bytes ( ) ) . map_err ( |e| {
110+ io:: Error :: new (
111+ io:: ErrorKind :: InvalidInput ,
112+ format ! ( "Failed to parse PEM certificate: {e}" ) ,
113+ )
114+ } ) ?;
115+ builder. add_root_certificate ( crt) ;
116+ }
117+ let connector = builder. build ( ) . map_err ( |e| {
118+ io:: Error :: new (
119+ io:: ErrorKind :: Other ,
120+ format ! ( "Failed to build TLS connector: {e}" ) ,
121+ )
122+ } ) ?;
123+ Ok ( PgTlsConnector :: NativeTls ( MakeTlsConnector :: new ( connector) ) )
124+ } ,
125+ None => Ok ( PgTlsConnector :: Plain ) ,
126+ }
127+ }
128+
91129 fn build_locking_key (
92130 & self , primary_namespace : & str , secondary_namespace : & str , key : & str ,
93131 ) -> String {
@@ -309,14 +347,17 @@ impl PaginatedKVStore for PostgresStore {
309347
310348struct PostgresStoreInner {
311349 client : tokio:: sync:: Mutex < Client > ,
312- connection_string : String ,
350+ config : Config ,
313351 kv_table_name : String ,
352+ tls : PgTlsConnector ,
314353 write_version_locks : Mutex < HashMap < String , Arc < tokio:: sync:: Mutex < u64 > > > > ,
315354 next_sort_order : AtomicI64 ,
316355}
317356
318357impl PostgresStoreInner {
319- async fn new ( connection_string : String , kv_table_name : Option < String > ) -> io:: Result < Self > {
358+ async fn new (
359+ connection_string : String , kv_table_name : Option < String > , tls : PgTlsConnector ,
360+ ) -> io:: Result < Self > {
320361 let kv_table_name = kv_table_name. unwrap_or ( DEFAULT_KV_TABLE_NAME . to_string ( ) ) ;
321362
322363 // If a dbname is specified in the connection string, ensure the database exists
@@ -327,10 +368,10 @@ impl PostgresStoreInner {
327368 } ) ?;
328369
329370 if let Some ( db_name) = config. get_dbname ( ) {
330- Self :: create_database_if_not_exists ( & connection_string , db_name) . await ?;
371+ Self :: create_database_if_not_exists ( & config , db_name, & tls ) . await ?;
331372 }
332373
333- let client = Self :: make_connection ( & connection_string ) . await ?;
374+ let client = Self :: make_config_connection ( & config , & tls ) . await ?;
334375
335376 // Create the KV data table if it doesn't exist.
336377 let sql = format ! (
@@ -399,29 +440,17 @@ impl PostgresStoreInner {
399440
400441 let client = tokio:: sync:: Mutex :: new ( client) ;
401442 let write_version_locks = Mutex :: new ( HashMap :: new ( ) ) ;
402- Ok ( Self { client, connection_string , kv_table_name, write_version_locks, next_sort_order } )
443+ Ok ( Self { client, config , kv_table_name, tls , write_version_locks, next_sort_order } )
403444 }
404445
405446 async fn create_database_if_not_exists (
406- connection_string : & str , db_name : & str ,
447+ config : & Config , db_name : & str , tls : & PgTlsConnector ,
407448 ) -> io:: Result < ( ) > {
408449 // Connect without a dbname (to the default database) so we can create the target.
409- let mut config: Config = connection_string. parse ( ) . map_err ( |e : PgError | {
410- let msg = format ! ( "Failed to parse PostgreSQL connection string: {e}" ) ;
411- io:: Error :: new ( io:: ErrorKind :: InvalidInput , msg)
412- } ) ?;
450+ let mut config = config. clone ( ) ;
413451 config. dbname ( "postgres" ) ;
414452
415- let ( client, connection) = config. connect ( NoTls ) . await . map_err ( |e| {
416- let msg = format ! ( "Failed to connect to PostgreSQL: {e}" ) ;
417- io:: Error :: new ( io:: ErrorKind :: Other , msg)
418- } ) ?;
419-
420- tokio:: spawn ( async move {
421- if let Err ( e) = connection. await {
422- log:: error!( "PostgreSQL connection error: {e}" ) ;
423- }
424- } ) ;
453+ let client = Self :: make_config_connection ( & config, tls) . await ?;
425454
426455 let row = client
427456 . query_opt ( "SELECT 1 FROM pg_database WHERE datname = $1" , & [ & db_name] )
@@ -443,27 +472,41 @@ impl PostgresStoreInner {
443472 Ok ( ( ) )
444473 }
445474
446- async fn make_connection ( connection_string : & str ) -> io:: Result < Client > {
447- let ( client , connection ) = connect ( connection_string , NoTls ) . await . map_err ( |e| {
475+ async fn make_config_connection ( config : & Config , tls : & PgTlsConnector ) -> io:: Result < Client > {
476+ let err_map = |e| {
448477 let msg = format ! ( "Failed to connect to PostgreSQL: {e}" ) ;
449478 io:: Error :: new ( io:: ErrorKind :: Other , msg)
450- } ) ?;
451-
452- tokio:: spawn ( async move {
453- if let Err ( e) = connection. await {
454- log:: error!( "PostgreSQL connection error: {e}" ) ;
455- }
456- } ) ;
479+ } ;
457480
458- Ok ( client)
481+ match tls {
482+ PgTlsConnector :: Plain => {
483+ let ( client, connection) = config. connect ( NoTls ) . await . map_err ( err_map) ?;
484+ tokio:: spawn ( async move {
485+ if let Err ( e) = connection. await {
486+ log:: error!( "PostgreSQL connection error: {e}" ) ;
487+ }
488+ } ) ;
489+ Ok ( client)
490+ } ,
491+ PgTlsConnector :: NativeTls ( tls_connector) => {
492+ let ( client, connection) =
493+ config. connect ( tls_connector. clone ( ) ) . await . map_err ( err_map) ?;
494+ tokio:: spawn ( async move {
495+ if let Err ( e) = connection. await {
496+ log:: error!( "PostgreSQL connection error: {e}" ) ;
497+ }
498+ } ) ;
499+ Ok ( client)
500+ } ,
501+ }
459502 }
460503
461504 async fn ensure_connected (
462505 & self , client : & mut tokio:: sync:: MutexGuard < ' _ , Client > ,
463506 ) -> io:: Result < ( ) > {
464507 if client. is_closed ( ) || client. check_connection ( ) . await . is_err ( ) {
465508 log:: debug!( "Reconnecting to PostgreSQL database" ) ;
466- let new_client = Self :: make_connection ( & self . connection_string ) . await ?;
509+ let new_client = Self :: make_config_connection ( & self . config , & self . tls ) . await ?;
467510 * * client = new_client;
468511 }
469512 Ok ( ( ) )
@@ -750,6 +793,19 @@ impl PostgresStoreInner {
750793 }
751794}
752795
796+ /// TLS configuration for PostgreSQL connections.
797+ #[ derive( Debug , Clone ) ]
798+ pub struct PostgresTlsConfig {
799+ /// PEM-encoded CA certificate. If `None`, the system's default root certificates are used.
800+ pub certificate_pem : Option < String > ,
801+ }
802+
803+ #[ derive( Clone ) ]
804+ enum PgTlsConnector {
805+ Plain ,
806+ NativeTls ( MakeTlsConnector ) ,
807+ }
808+
753809#[ cfg( test) ]
754810mod tests {
755811 use super :: * ;
@@ -761,7 +817,7 @@ mod tests {
761817 }
762818
763819 fn create_test_store ( table_name : & str ) -> PostgresStore {
764- PostgresStore :: new ( test_connection_string ( ) , Some ( table_name. to_string ( ) ) ) . unwrap ( )
820+ PostgresStore :: new ( test_connection_string ( ) , Some ( table_name. to_string ( ) ) , None ) . unwrap ( )
765821 }
766822
767823 fn cleanup_store ( store : & PostgresStore ) {
@@ -1092,4 +1148,25 @@ mod tests {
10921148 cleanup_store ( & store) ;
10931149 }
10941150 }
1151+
1152+ #[ test]
1153+ fn test_tls_config_none_builds_plain_connector ( ) {
1154+ let connector = PostgresStore :: build_tls_connector ( None ) . unwrap ( ) ;
1155+ assert ! ( matches!( connector, PgTlsConnector :: Plain ) ) ;
1156+ }
1157+
1158+ #[ test]
1159+ fn test_tls_config_system_certs_builds_native_tls_connector ( ) {
1160+ let config = Some ( PostgresTlsConfig { certificate_pem : None } ) ;
1161+ let connector = PostgresStore :: build_tls_connector ( config) . unwrap ( ) ;
1162+ assert ! ( matches!( connector, PgTlsConnector :: NativeTls ( _) ) ) ;
1163+ }
1164+
1165+ #[ test]
1166+ fn test_tls_config_invalid_pem_returns_error ( ) {
1167+ let config =
1168+ Some ( PostgresTlsConfig { certificate_pem : Some ( "not-a-valid-pem" . to_string ( ) ) } ) ;
1169+ let result = PostgresStore :: build_tls_connector ( config) ;
1170+ assert ! ( result. is_err( ) ) ;
1171+ }
10951172}
0 commit comments