22import csv
33import importlib
44import logging
5+ import math
56import re
67from collections .abc import Callable , Iterable
78from contextlib import suppress
1011from urllib .parse import quote_plus
1112
1213import psycopg2 .sql
13- from sqlalchemy import create_engine , text
14- from sqlalchemy .engine import LegacyRow , RowMapping
15- from sqlalchemy .engine .base import Connection , Engine
14+ from sqlalchemy import Connection , Engine , Row , create_engine , text
15+ from sqlalchemy .engine import RowMapping
1616from sqlalchemy .exc import ProgrammingError , SQLAlchemyError
17- from sqlalchemy .pool . base import _ConnectionFairy
17+ from sqlalchemy .pool import PoolProxiedConnection
1818
1919from testgen import settings
2020from testgen .common .credentials import (
3232 SQLFlavor ,
3333 resolve_connection_params ,
3434)
35- from testgen .common .standalone_postgres import get_connection_string as get_standalone_connection_string , is_standalone_mode
3635from testgen .common .read_file import get_template_files
36+ from testgen .common .standalone_postgres import get_connection_string as get_standalone_connection_string
37+ from testgen .common .standalone_postgres import is_standalone_mode
3738from testgen .utils import get_exception_message
3839
3940LOG = logging .getLogger ("testgen" )
@@ -103,12 +104,14 @@ def create_database(
103104) -> None :
104105 LOG .debug ("DB operation: create_database on App database (User type = database_admin)" )
105106
107+ # DDL like CREATE/DROP DATABASE cannot run inside a transaction.
108+ # Use AUTOCOMMIT isolation so each statement commits immediately.
106109 connection = _init_db_connection (
107110 user_override = params ["TESTGEN_ADMIN_USER" ],
108111 password_override = params ["TESTGEN_ADMIN_PASSWORD" ],
109112 user_type = "database_admin" ,
110113 )
111- connection . execute ( "commit " )
114+ connection = connection . execution_options ( isolation_level = "AUTOCOMMIT " )
112115
113116 with connection :
114117 if drop_existing :
@@ -118,20 +121,16 @@ def create_database(
118121 ),
119122 {"database_name" : database_name },
120123 )
121- connection .execute ("commit" )
122- connection .execute (f"DROP DATABASE IF EXISTS { database_name } " )
123- connection .execute ("commit" )
124+ connection .execute (text (f"DROP DATABASE IF EXISTS { database_name } " ))
124125 if drop_users_and_roles :
125126 if user := params .get ("TESTGEN_USER" ):
126- connection .execute (f"DROP USER IF EXISTS { user } " )
127+ connection .execute (text ( f"DROP USER IF EXISTS { user } " ) )
127128 if report_user := params .get ("TESTGEN_REPORT_USER" ):
128- connection .execute (f"DROP USER IF EXISTS { report_user } " )
129- connection .execute ("DROP ROLE IF EXISTS testgen_execute_role" )
130- connection .execute ("DROP ROLE IF EXISTS testgen_report_role" )
131- connection .execute ("commit" )
129+ connection .execute (text (f"DROP USER IF EXISTS { report_user } " ))
130+ connection .execute (text ("DROP ROLE IF EXISTS testgen_execute_role" ))
131+ connection .execute (text ("DROP ROLE IF EXISTS testgen_report_role" ))
132132 with suppress (ProgrammingError ):
133- connection .execute (f"CREATE DATABASE { database_name } " )
134- connection .close ()
133+ connection .execute (text (f"CREATE DATABASE { database_name } " ))
135134
136135
137136def execute_db_queries (
@@ -150,7 +149,6 @@ def execute_db_queries(
150149 LOG .debug ("No queries to process" )
151150 for index , (query , params ) in enumerate (queries ):
152151 LOG .debug (f"Query { index + 1 } of { len (queries )} : { query } " )
153- transaction = connection .begin ()
154152 result = connection .execute (text (query ), params )
155153 row_counts .append (result .rowcount )
156154 if result .rowcount == - 1 :
@@ -163,7 +161,7 @@ def execute_db_queries(
163161 except Exception :
164162 return_values .append (None )
165163
166- transaction .commit ()
164+ connection .commit ()
167165 LOG .debug (message )
168166
169167 return return_values , row_counts
@@ -180,28 +178,28 @@ def fetch_from_db_threaded(
180178 use_target_db : bool = False ,
181179 max_threads : int = 4 ,
182180 progress_callback : Callable [[ThreadedProgress ], None ] | None = None ,
183- ) -> tuple [list [LegacyRow ], list [str ], dict [int , str ]]:
181+ ) -> tuple [list [RowMapping ], list [str ], dict [int , str ]]:
184182 LOG .debug (f"DB operation: fetch_from_db_threaded ({ len (queries )} ) on { 'Target' if use_target_db else 'App' } database (User type = normal)" )
185183
186- def fetch_data (query : str , params : dict | None , index : int ) -> tuple [list [LegacyRow ], list [str ], int , str | None ]:
184+ def fetch_data (query : str , params : dict | None , index : int ) -> tuple [list [RowMapping ], list [str ], int , str | None ]:
187185 LOG .debug (f"Query: { query } " )
188- row_data : list [LegacyRow ] = []
186+ row_data : list [RowMapping ] = []
189187 column_names : list [str ] = []
190188 error = None
191189
192190 try :
193191 with _init_db_connection (use_target_db ) as connection :
194192 result = connection .execute (text (query ), params )
195193 LOG .debug (f"{ result .rowcount } records retrieved" )
196- row_data = result .fetchall ()
194+ row_data = result .mappings (). fetchall ()
197195 column_names = list (result .keys ())
198196 except Exception as e :
199197 error = get_exception_message (e )
200198 LOG .exception (f"Failed to execute threaded query: { query } " )
201199
202200 return row_data , column_names , index , error
203201
204- result_data : list [LegacyRow ] = []
202+ result_data : list [RowMapping ] = []
205203 result_columns : list [str ] = []
206204 error_data : dict [int , str ] = {}
207205
@@ -241,7 +239,7 @@ def fetch_data(query: str, params: dict | None, index: int) -> tuple[list[Legacy
241239
242240def fetch_list_from_db (
243241 query : str , params : dict | None = None , use_target_db : bool = False
244- ) -> tuple [list [LegacyRow ], list [str ]]:
242+ ) -> tuple [list [Row ], list [str ]]:
245243 LOG .debug (f"DB operation: fetch_list_from_db on { 'Target' if use_target_db else 'App' } database (User type = normal)" )
246244
247245 with _init_db_connection (use_target_db ) as connection :
@@ -263,21 +261,29 @@ def fetch_dict_from_db(
263261 LOG .debug (f"Query: { query } " )
264262 result = connection .execute (text (query ), params )
265263 LOG .debug (f"{ result .rowcount } records retrieved" )
266- # Creates list of dictionaries so records are addressible by column name
267- return [row ._mapping for row in result ]
264+ return result .mappings ().all ()
268265
269266
270- def write_to_app_db (data : list [LegacyRow ], column_names : Iterable [str ], table_name : str ) -> None :
267+ def write_to_app_db (data : list [Row ], column_names : Iterable [str ], table_name : str ) -> None :
271268 LOG .debug ("DB operation: write_to_app_db on App database (User type = normal)" )
272269
273270 # use_raw is required to make use of the copy_expert method for fast batch ingestion
274271 connection = _init_db_connection (use_raw = True )
275272 cursor = connection .cursor ()
276273
277274 # Write List to CSV in memory
275+ # Sanitize NaN → None: some DB connectors (e.g. Databricks via Arrow) return
276+ # float('nan') for NULL integers. CSV would serialize these as "nan" which
277+ # PostgreSQL rejects for numeric columns.
278+ # RowMapping objects iterate over keys, not values — extract values explicitly.
279+ def _row_values (row ):
280+ values = row .values () if isinstance (row , RowMapping ) else row
281+ return tuple (None if isinstance (v , float ) and math .isnan (v ) else v for v in values )
282+
283+ sanitized = [_row_values (row ) for row in data ]
278284 buffer = FilteredStringIO (["\x00 " ])
279285 writer = csv .writer (buffer , quoting = csv .QUOTE_MINIMAL )
280- writer .writerows (data )
286+ writer .writerows (sanitized )
281287 buffer .seek (0 )
282288
283289 # List should have same column names as destination table, though not all columns in table are required
@@ -362,7 +368,7 @@ def _init_app_db_connection(
362368 password_override : str | None = None ,
363369 user_type : UserType = "normal" ,
364370 use_raw : bool = False ,
365- ) -> Connection | _ConnectionFairy :
371+ ) -> Connection | PoolProxiedConnection :
366372 database_name = "postgres" if user_type == "database_admin" else get_tg_db ()
367373 is_admin = user_type == "database_admin" or user_type == "schema_admin"
368374
@@ -399,7 +405,7 @@ def _init_app_db_connection(
399405 try :
400406 schema_name = "public" if is_admin else get_tg_schema ()
401407 if use_raw :
402- connection : _ConnectionFairy = engine .raw_connection ()
408+ connection : PoolProxiedConnection = engine .raw_connection ()
403409 with connection .cursor () as cursor :
404410 cursor .execute (
405411 "SET SEARCH_PATH = %(schema_name)s" ,
0 commit comments