@@ -7,14 +7,13 @@ use api::types::{
77 ListKeyVersionsRequest , ListKeyVersionsResponse , PutObjectRequest , PutObjectResponse ,
88} ;
99use async_trait:: async_trait;
10- use bb8_postgres:: bb8:: Pool ;
11- use bb8_postgres:: PostgresConnectionManager ;
1210use bytes:: Bytes ;
1311use chrono:: Utc ;
1412use native_tls:: TlsConnector ;
1513use postgres_native_tls:: MakeTlsConnector ;
1614use std:: cmp:: min;
1715use std:: io:: { self , Error , ErrorKind } ;
16+ use tokio:: sync:: Mutex ;
1817use tokio_postgres:: tls:: { MakeTlsConnect , TlsConnect } ;
1918use tokio_postgres:: { error, Client , NoTls , Socket , Transaction } ;
2019
@@ -47,6 +46,72 @@ pub const LIST_KEY_VERSIONS_MAX_PAGE_SIZE: i32 = 100;
4746/// Exceeding this value will result in request rejection through [`VssError::InvalidRequestError`].
4847pub const MAX_PUT_REQUEST_ITEM_COUNT : usize = 1000 ;
4948
49+ const POOL_SIZE : usize = 10 ;
50+
51+ struct SmallPool < T > {
52+ connections : [ Mutex < Client > ; POOL_SIZE ] ,
53+ endpoint : String ,
54+ db_name : String ,
55+ tls : T ,
56+ }
57+
58+ impl < T > SmallPool < T >
59+ where
60+ T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
61+ T :: Stream : Send + Sync ,
62+ T :: TlsConnect : Send ,
63+ <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
64+ {
65+ async fn new ( postgres_endpoint : & str , vss_db : & str , tls : T ) -> Result < Self , Error > {
66+ let connections = [
67+ Mutex :: new ( make_db_connection ( postgres_endpoint, vss_db, tls. clone ( ) ) . await ?) ,
68+ Mutex :: new ( make_db_connection ( postgres_endpoint, vss_db, tls. clone ( ) ) . await ?) ,
69+ Mutex :: new ( make_db_connection ( postgres_endpoint, vss_db, tls. clone ( ) ) . await ?) ,
70+ Mutex :: new ( make_db_connection ( postgres_endpoint, vss_db, tls. clone ( ) ) . await ?) ,
71+ Mutex :: new ( make_db_connection ( postgres_endpoint, vss_db, tls. clone ( ) ) . await ?) ,
72+ Mutex :: new ( make_db_connection ( postgres_endpoint, vss_db, tls. clone ( ) ) . await ?) ,
73+ Mutex :: new ( make_db_connection ( postgres_endpoint, vss_db, tls. clone ( ) ) . await ?) ,
74+ Mutex :: new ( make_db_connection ( postgres_endpoint, vss_db, tls. clone ( ) ) . await ?) ,
75+ Mutex :: new ( make_db_connection ( postgres_endpoint, vss_db, tls. clone ( ) ) . await ?) ,
76+ Mutex :: new ( make_db_connection ( postgres_endpoint, vss_db, tls. clone ( ) ) . await ?) ,
77+ ] ;
78+
79+ let pool = SmallPool {
80+ connections,
81+ endpoint : String :: from ( postgres_endpoint) ,
82+ db_name : String :: from ( vss_db) ,
83+ tls,
84+ } ;
85+ Ok ( pool)
86+ }
87+
88+ async fn get ( & self ) -> Result < tokio:: sync:: MutexGuard < ' _ , Client > , Error > {
89+ let mut conn = tokio:: select! {
90+ conn_0 = self . connections[ 0 ] . lock( ) => conn_0,
91+ conn_1 = self . connections[ 1 ] . lock( ) => conn_1,
92+ conn_2 = self . connections[ 2 ] . lock( ) => conn_2,
93+ conn_3 = self . connections[ 3 ] . lock( ) => conn_3,
94+ conn_4 = self . connections[ 4 ] . lock( ) => conn_4,
95+ conn_5 = self . connections[ 5 ] . lock( ) => conn_5,
96+ conn_6 = self . connections[ 6 ] . lock( ) => conn_6,
97+ conn_7 = self . connections[ 7 ] . lock( ) => conn_7,
98+ conn_8 = self . connections[ 8 ] . lock( ) => conn_8,
99+ conn_9 = self . connections[ 9 ] . lock( ) => conn_9,
100+ } ;
101+ self . ensure_connected ( & mut conn) . await ?;
102+ Ok ( conn)
103+ }
104+
105+ async fn ensure_connected ( & self , client : & mut Client ) -> Result < ( ) , Error > {
106+ if client. is_closed ( ) || client. check_connection ( ) . await . is_err ( ) {
107+ let new_client =
108+ make_db_connection ( & self . endpoint , & self . db_name , self . tls . clone ( ) ) . await ?;
109+ * client = new_client;
110+ }
111+ Ok ( ( ) )
112+ }
113+ }
114+
50115/// A [PostgreSQL](https://www.postgresql.org/) based backend implementation for VSS.
51116pub struct PostgresBackend < T >
52117where
55120 <T as MakeTlsConnect < Socket > >:: TlsConnect : Send ,
56121 <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
57122{
58- pool : Pool < PostgresConnectionManager < T > > ,
123+ pool : SmallPool < T > ,
59124}
60125
61126/// A postgres backend with plaintext connections to the database
@@ -183,22 +248,8 @@ where
183248 postgres_endpoint : & str , default_db : & str , vss_db : & str , tls : T ,
184249 ) -> Result < Self , Error > {
185250 create_database ( postgres_endpoint, default_db, vss_db, tls. clone ( ) ) . await ?;
186- let vss_dsn = format ! ( "{}/{}" , postgres_endpoint, vss_db) ;
187- let manager =
188- PostgresConnectionManager :: new_from_stringlike ( vss_dsn, tls) . map_err ( |e| {
189- Error :: new (
190- ErrorKind :: Other ,
191- format ! ( "Failed to create PostgresConnectionManager: {}" , e) ,
192- )
193- } ) ?;
194- // By default, Pool maintains 0 long-running connections, so returning a pool
195- // here is no guarantee that Pool established a connection to the database.
196- //
197- // See Builder::min_idle to increase the long-running connection count.
198- let pool = Pool :: builder ( )
199- . build ( manager)
200- . await
201- . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Failed to build Pool: {}" , e) ) ) ?;
251+
252+ let pool = SmallPool :: new ( postgres_endpoint, vss_db, tls) . await ?;
202253 let postgres_backend = PostgresBackend { pool } ;
203254
204255 #[ cfg( not( test) ) ]
@@ -208,10 +259,7 @@ where
208259 }
209260
210261 async fn migrate_vss_database ( & self , migrations : & [ & str ] ) -> Result < ( usize , usize ) , Error > {
211- let mut conn = self . pool . get ( ) . await . map_err ( |e| {
212- Error :: new ( ErrorKind :: Other , format ! ( "Failed to fetch a connection from Pool: {}" , e) )
213- } ) ?;
214-
262+ let mut conn = self . pool . get ( ) . await ?;
215263 // Get the next migration to be applied.
216264 let migration_start = match conn. query_one ( GET_VERSION_STMT , & [ ] ) . await {
217265 Ok ( row) => {
@@ -464,11 +512,7 @@ where
464512 async fn get (
465513 & self , user_token : String , request : GetObjectRequest ,
466514 ) -> Result < GetObjectResponse , VssError > {
467- let conn = self
468- . pool
469- . get ( )
470- . await
471- . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
515+ let conn = self . pool . get ( ) . await ?;
472516 let stmt = "SELECT key, value, version FROM vss_db WHERE user_token = $1 AND store_id = $2 AND key = $3" ;
473517 let row = conn
474518 . query_opt ( stmt, & [ & user_token, & request. store_id , & request. key ] )
@@ -525,11 +569,7 @@ where
525569 vss_put_records. push ( global_version_record) ;
526570 }
527571
528- let mut conn = self
529- . pool
530- . get ( )
531- . await
532- . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
572+ let mut conn = self . pool . get ( ) . await ?;
533573 let transaction = conn
534574 . transaction ( )
535575 . await
@@ -573,11 +613,7 @@ where
573613 } ) ?;
574614 let vss_record = self . build_vss_record ( user_token, store_id, key_value) ;
575615
576- let mut conn = self
577- . pool
578- . get ( )
579- . await
580- . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
616+ let mut conn = self . pool . get ( ) . await ?;
581617 let transaction = conn
582618 . transaction ( )
583619 . await
@@ -622,11 +658,7 @@ where
622658
623659 let limit = min ( page_size, LIST_KEY_VERSIONS_MAX_PAGE_SIZE ) as i64 ;
624660
625- let conn = self
626- . pool
627- . get ( )
628- . await
629- . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
661+ let conn = self . pool . get ( ) . await ?;
630662
631663 let stmt = "SELECT key, version FROM vss_db WHERE user_token = $1 AND store_id = $2 AND key > $3 AND key LIKE $4 ORDER BY key LIMIT $5" ;
632664
0 commit comments