55from tempfile import NamedTemporaryFile
66from uuid import uuid4
77
8+ import boto3
89import psycopg2
910from psycopg2 .extras import Range
1011
2324
2425logger = logging .getLogger (__name__ )
2526
26- try :
27- import boto3
28-
29- IAM_ENABLED = True
30- except ImportError :
31- IAM_ENABLED = False
32-
3327types_map = {
3428 20 : TYPE_INTEGER ,
3529 21 : TYPE_INTEGER ,
@@ -177,6 +171,8 @@ def configuration_schema(cls):
177171 "sslrootcertFile" : {"type" : "string" , "title" : "SSL Root Certificate" },
178172 "sslcertFile" : {"type" : "string" , "title" : "SSL Client Certificate" },
179173 "sslkeyFile" : {"type" : "string" , "title" : "SSL Client Key" },
174+ "awsIamAuth" : {"type" : "boolean" , "title" : "AWS IAM authentication" },
175+ "awsRegion" : {"type" : "string" , "title" : "AWS Region" },
180176 },
181177 "order" : ["host" , "port" , "user" , "password" ],
182178 "required" : ["dbname" ],
@@ -186,6 +182,8 @@ def configuration_schema(cls):
186182 "sslrootcertFile" ,
187183 "sslcertFile" ,
188184 "sslkeyFile" ,
185+ "awsIamAuth" ,
186+ "awsRegion" ,
189187 ],
190188 }
191189
@@ -255,11 +253,27 @@ def _get_tables(self, schema):
255253 def _get_connection (self ):
256254 self .ssl_config = _get_ssl_config (self .configuration )
257255 self .dsn = _parse_dsn (self .configuration )
256+
257+ user = self .configuration .get ("user" )
258+ password = self .configuration .get ("password" )
259+ host = self .configuration .get ("host" )
260+ port = self .configuration .get ("port" , 5432 )
261+
262+ if self .configuration .get ("awsIamAuth" , False ):
263+ region_name = self .configuration .get ("awsRegion" )
264+ rds_client = boto3 .client ("rds" , region_name = region_name )
265+ auth_token = rds_client .generate_db_auth_token (
266+ DBHostname = host ,
267+ Port = port ,
268+ DBUsername = user ,
269+ )
270+ password = auth_token
271+
258272 connection = psycopg2 .connect (
259- user = self . configuration . get ( " user" ) ,
260- password = self . configuration . get ( " password" ) ,
261- host = self . configuration . get ( " host" ) ,
262- port = self . configuration . get ( " port" ) ,
273+ user = user ,
274+ password = password ,
275+ host = host ,
276+ port = port ,
263277 dbname = self .configuration .get ("dbname" ),
264278 async_ = True ,
265279 ** self .ssl_config ,
@@ -426,7 +440,7 @@ def name(cls):
426440
427441 @classmethod
428442 def enabled (cls ):
429- return IAM_ENABLED
443+ return True
430444
431445 def _login_method_selection (self ):
432446 if self .configuration .get ("rolename" ):
0 commit comments