55from multiprocessing import Lock , Value
66from pathlib import Path
77from types import ModuleType
8- from typing import Any , Dict , List , Union
8+ from typing import Any , Callable , Dict , List , Tuple , Union
99
1010from _tool_helpers import print_debug , print_error , print_info , print_warning
1111
@@ -49,14 +49,14 @@ def __init__(
4949
5050 self .aws_wrapper : ModuleType | None = None
5151 self .cx_oracle : ModuleType | None = None
52- self .mssql_connect : ModuleType | None = None
52+ self .mssql_connect : Callable [..., Any ] | None = None
5353 self .mysql_connector : ModuleType | None = None
5454 self .psycopg_connection : ModuleType | None = None
5555 self .psycopg2 : ModuleType | None = None
5656 self .pyodbc : ModuleType | None = None
5757 self .sqlite3 : ModuleType | None = None
58- self .imported_aws_wrapper = False
5958
59+ self .imported_aws_wrapper = False
6060 self .imported_aws_wrapper_error : ImportError | None = None
6161 self .imported_cx_oracle = False
6262 self .imported_cx_oracle_error : ImportError | None = None
@@ -84,7 +84,7 @@ def __init__(
8484 if not (connection_string := engine_configuration .get ("SQL" ).get ("CONNECTION" )):
8585 raise LookupError (f"Database connection settings appear to be missing or invalid" )
8686
87- self .main_db_type , _ = connection_string .split ("://" ) if "://" in connection_string else ("UNKNOWN_DB" )
87+ self .main_db_type , _ = connection_string .split ("://" ) if "://" in connection_string else ("UNKNOWN_DB" , "" )
8888 if (main_db_type_upper := self .main_db_type .upper ()) not in SUPPORTED_DBS :
8989 raise LookupError (f"Unsupported database type: { self .main_db_type } " )
9090
@@ -94,7 +94,7 @@ def __init__(
9494 self .connect ("MAIN" )
9595
9696 for table_name in engine_configuration .get ("HYBRID" , {}).keys ():
97- node = engine_configuration ["HYBRID" ][table_name ]
97+ node = engine_configuration ["HYBRID" ][table_name ] # type: ignore
9898 if node not in self .connections :
9999 self .connections [node ] = {}
100100 # self.connect(node, engine_configuration[node]["DB_1"])
@@ -256,7 +256,7 @@ def connect(self, node: str) -> None:
256256 conn_str , self .connections [node ]["query_params" ], ["driver" ]
257257 )
258258
259- self .connections [node ]["dbo" ] = self .mssql_connect (conn_str )
259+ self .connections [node ]["dbo" ] = self .mssql_connect (conn_str ) # type: ignore
260260
261261 # pyodbc for engine config: CONNECTION=mssql://username:password@database
262262 else :
@@ -311,7 +311,7 @@ def connect(self, node: str) -> None:
311311 except Exception as err :
312312 raise Exception (err )
313313
314- def set_node (self , sql : str ) -> Union [ str , List [ str ]] :
314+ def set_node (self , sql : str ) -> str :
315315 if len (self .connections ) == 1 :
316316 return "MAIN"
317317
@@ -360,14 +360,16 @@ def close(self) -> None:
360360 with SzDatabase ._aurora_clean_up_msg_lock :
361361 if not SzDatabase ._aurora_clean_up_msg_flag .value :
362362 SzDatabase ._aurora_clean_up_msg_flag .value = True
363- print_info ("Cleaning up Aurora database resources, this can take a minute or two..." , end_str = "\n " )
363+ print_info (
364+ "Cleaning up Aurora database resources, this can take a minute or two..." , end_str = "\n "
365+ )
364366
365367 # Release resources for all Aurora nodes with IAM auth
366368 for node in self .connections .keys ():
367369 if self .connections [node ]["dbtype" ] == "AURORAPOSTGRESQL" and self .connections [node ]["iam_auth" ]:
368370 self .connections [node ]["dbo" ].release_resources ()
369371
370- def sql_prep (self , sql : str , return_node : bool = False ) -> str : # left in for backwards compatibility
372+ def sql_prep (self , sql : str , return_node : bool = False ) -> Union [ str , Tuple [ str , str ]]:
371373 node = self .set_node (sql )
372374
373375 if (
@@ -392,7 +394,7 @@ def sql_exec(self, raw_sql: str, param_list=None, **kwargs):
392394 sql = self .statement_cache [raw_sql ]["sql" ]
393395 node = self .statement_cache [raw_sql ]["node" ]
394396 else :
395- sql , node = self .sql_prep (raw_sql , return_node = True )
397+ sql , node = self .sql_prep (raw_sql , return_node = True ) # type: ignore
396398 self .statement_cache [raw_sql ] = {"sql" : sql , "node" : node }
397399
398400 if param_list and type (param_list ) not in (list , tuple ):
@@ -483,7 +485,7 @@ def fetch_many_dicts(self, cursor_data, row_count):
483485
484486 def dburi_parse (self , node : str , db_uri : str ) -> None :
485487 """Parse the database URI"""
486- uri_dict = {}
488+ uri_dict : dict [ str , Any ] = {}
487489
488490 try :
489491 uri_dict ["TABLE" ] = uri_dict ["SCHEMA" ] = uri_dict ["PORT" ] = uri_dict ["DBURI_PARMS" ] = uri_dict [
@@ -746,7 +748,7 @@ def import_mysql_connector(self):
746748 def import_psycopg2 (self ):
747749 """Try and import psycopg2"""
748750 try :
749- import psycopg2
751+ import psycopg2 # type: ignore
750752
751753 self .psycopg2 = psycopg2
752754 self .imported_psycopg2 = True
@@ -776,7 +778,7 @@ def import_sqlite3(self):
776778 # pylint: enable=C0415
777779
778780 def append_uri_query_params (
779- self , string_to_append : str , params_dict : Dict [str , str ], keys_to_ignore : List [str ] = None
781+ self , string_to_append : str , params_dict : Dict [str , str ], keys_to_ignore : List [str ] | None = None
780782 ) -> str :
781783 """
782784 Add any query parameters to connection strings when specified, for example with:
0 commit comments