@@ -16,7 +16,7 @@ use lightning::util::persist::{
1616 KVStore , KVStoreSync , PageToken , PaginatedKVStore , PaginatedKVStoreSync , PaginatedListResponse ,
1717} ;
1818use lightning_types:: string:: PrintableString ;
19- use tokio_postgres:: NoTls ;
19+ use tokio_postgres:: { connect , Client , Config , Error as PgError , NoTls } ;
2020
2121use crate :: io:: utils:: check_namespace_key_validity;
2222
@@ -61,6 +61,9 @@ impl PostgresStore {
6161 ///
6262 /// Connects to the PostgreSQL database at the given `connection_string`.
6363 ///
64+ /// If the connection string includes a `dbname`, the database will be created automatically
65+ /// if it doesn't already exist.
66+ ///
6467 /// The given `kv_table_name` will be used or default to [`DEFAULT_KV_TABLE_NAME`].
6568 pub fn new ( connection_string : String , kv_table_name : Option < String > ) -> io:: Result < Self > {
6669 let internal_runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
@@ -306,7 +309,7 @@ impl PaginatedKVStore for PostgresStore {
306309}
307310
308311struct PostgresStoreInner {
309- client : tokio:: sync:: Mutex < tokio_postgres :: Client > ,
312+ client : tokio:: sync:: Mutex < Client > ,
310313 kv_table_name : String ,
311314 write_version_locks : Mutex < HashMap < String , Arc < tokio:: sync:: Mutex < u64 > > > > ,
312315 next_sort_order : AtomicI64 ,
@@ -316,11 +319,21 @@ impl PostgresStoreInner {
316319 async fn new ( connection_string : & str , kv_table_name : Option < String > ) -> io:: Result < Self > {
317320 let kv_table_name = kv_table_name. unwrap_or ( DEFAULT_KV_TABLE_NAME . to_string ( ) ) ;
318321
319- let ( client, connection) =
320- tokio_postgres:: connect ( connection_string, NoTls ) . await . map_err ( |e| {
321- let msg = format ! ( "Failed to connect to PostgreSQL: {e}" ) ;
322- io:: Error :: new ( io:: ErrorKind :: Other , msg)
323- } ) ?;
322+ // If a dbname is specified in the connection string, ensure the database exists
323+ // by first connecting without a dbname and creating it if necessary.
324+ let config: Config = connection_string. parse ( ) . map_err ( |e : PgError | {
325+ let msg = format ! ( "Failed to parse PostgreSQL connection string: {e}" ) ;
326+ io:: Error :: new ( io:: ErrorKind :: InvalidInput , msg)
327+ } ) ?;
328+
329+ if let Some ( db_name) = config. get_dbname ( ) {
330+ Self :: create_database_if_not_exists ( connection_string, db_name) . await ?;
331+ }
332+
333+ let ( client, connection) = connect ( connection_string, NoTls ) . await . map_err ( |e| {
334+ let msg = format ! ( "Failed to connect to PostgreSQL: {e}" ) ;
335+ io:: Error :: new ( io:: ErrorKind :: Other , msg)
336+ } ) ?;
324337
325338 // Spawn the connection task so it runs in the background.
326339 tokio:: spawn ( async move {
@@ -399,6 +412,47 @@ impl PostgresStoreInner {
399412 Ok ( Self { client, kv_table_name, write_version_locks, next_sort_order } )
400413 }
401414
415+ async fn create_database_if_not_exists (
416+ connection_string : & str , db_name : & str ,
417+ ) -> io:: Result < ( ) > {
418+ // Connect without a dbname (to the default database) so we can create the target.
419+ let mut config: Config = connection_string. parse ( ) . map_err ( |e : PgError | {
420+ let msg = format ! ( "Failed to parse PostgreSQL connection string: {e}" ) ;
421+ io:: Error :: new ( io:: ErrorKind :: InvalidInput , msg)
422+ } ) ?;
423+ config. dbname ( "postgres" ) ;
424+
425+ let ( client, connection) = config. connect ( NoTls ) . await . map_err ( |e| {
426+ let msg = format ! ( "Failed to connect to PostgreSQL: {e}" ) ;
427+ io:: Error :: new ( io:: ErrorKind :: Other , msg)
428+ } ) ?;
429+
430+ tokio:: spawn ( async move {
431+ if let Err ( e) = connection. await {
432+ log:: error!( "PostgreSQL connection error: {e}" ) ;
433+ }
434+ } ) ;
435+
436+ let row = client
437+ . query_opt ( "SELECT 1 FROM pg_database WHERE datname = $1" , & [ & db_name] )
438+ . await
439+ . map_err ( |e| {
440+ let msg = format ! ( "Failed to check for database {db_name}: {e}" ) ;
441+ io:: Error :: new ( io:: ErrorKind :: Other , msg)
442+ } ) ?;
443+
444+ if row. is_none ( ) {
445+ let sql = format ! ( "CREATE DATABASE {db_name}" ) ;
446+ client. execute ( & sql, & [ ] ) . await . map_err ( |e| {
447+ let msg = format ! ( "Failed to create database {db_name}: {e}" ) ;
448+ io:: Error :: new ( io:: ErrorKind :: Other , msg)
449+ } ) ?;
450+ log:: info!( "Created database {db_name}" ) ;
451+ }
452+
453+ Ok ( ( ) )
454+ }
455+
402456 fn get_inner_lock_ref ( & self , locking_key : String ) -> Arc < tokio:: sync:: Mutex < u64 > > {
403457 let mut outer_lock = self . write_version_locks . lock ( ) . unwrap ( ) ;
404458 Arc :: clone ( & outer_lock. entry ( locking_key) . or_default ( ) )
0 commit comments