1+ import contextlib
12import typing as t
23from collections import deque
34from typing import Any , Optional , Tuple , Union
@@ -39,7 +40,9 @@ class AsyncAdapt_psqlpy_cursor(AsyncAdapt_dbapi_cursor):
3940 _adapt_connection : "AsyncAdapt_psqlpy_connection"
4041 _connection : psqlpy .Connection
4142
42- def __init__ (self , adapt_connection : AsyncAdapt_dbapi_connection ) -> None :
43+ def __init__ (
44+ self , adapt_connection : "AsyncAdapt_psqlpy_connection"
45+ ) -> None :
4346 self ._adapt_connection = adapt_connection
4447 self ._connection = adapt_connection ._connection
4548 self ._rows : deque [t .Any ] = deque ()
@@ -177,10 +180,9 @@ def process_value(value: Any) -> Any:
177180 return {
178181 key : process_value (value ) for key , value in parameters .items ()
179182 }
180- elif isinstance (parameters , (list , tuple )):
183+ if isinstance (parameters , (list , tuple )):
181184 return [process_value (value ) for value in parameters ]
182- else :
183- return process_value (parameters )
185+ return process_value (parameters )
184186
185187 def _convert_named_params_with_casting (
186188 self ,
@@ -330,9 +332,21 @@ async def _executemany(
330332 self ._process_parameters (params ) for params in seq_of_parameters
331333 ]
332334
335+ # Convert to the expected type for execute_many
336+ converted_seq : t .List [t .List [t .Any ]] = []
337+ for params in processed_seq :
338+ if params is None :
339+ converted_seq .append ([])
340+ elif isinstance (params , dict ):
341+ converted_seq .append (list (params .values ()))
342+ elif isinstance (params , (list , tuple )):
343+ converted_seq .append (list (params ))
344+ else :
345+ converted_seq .append ([params ])
346+
333347 return await self ._connection .execute_many (
334348 operation ,
335- processed_seq ,
349+ converted_seq ,
336350 True ,
337351 )
338352
@@ -360,7 +374,7 @@ class AsyncAdapt_psqlpy_ss_cursor(
360374):
361375 """Enhanced server-side cursor with better async iteration support"""
362376
363- _cursor : psqlpy .Cursor
377+ _cursor : t . Optional [ psqlpy .Cursor ]
364378
365379 def __init__ (
366380 self , adapt_connection : "AsyncAdapt_psqlpy_connection"
@@ -377,7 +391,7 @@ def _convert_result(
377391 ) -> Tuple [Tuple [Any , ...], ...]:
378392 """Enhanced result conversion with better error handling"""
379393 if result is None :
380- return tuple ()
394+ return ()
381395
382396 try :
383397 return tuple (
@@ -386,7 +400,7 @@ def _convert_result(
386400 )
387401 except Exception :
388402 # Return empty tuple on conversion error
389- return tuple ()
403+ return ()
390404
391405 def close (self ) -> None :
392406 """Enhanced close with proper state management"""
@@ -456,6 +470,7 @@ class AsyncAdapt_psqlpy_connection(AsyncAdapt_dbapi_connection):
456470 _ss_cursor_cls = AsyncAdapt_psqlpy_ss_cursor
457471
458472 _connection : psqlpy .Connection
473+ _transaction : t .Optional [psqlpy .Transaction ]
459474
460475 __slots__ = (
461476 "_invalidate_schema_cache_asof" ,
@@ -513,7 +528,7 @@ def rollback(self) -> None:
513528 if self ._transaction is not None :
514529 await_only (self ._transaction .rollback ())
515530 else :
516- await_only (self ._connection .rollback ())
531+ await_only (self ._connection .rollback ()) # type: ignore[attr-defined]
517532 self ._performance_stats ["transactions_rolled_back" ] += 1
518533 except Exception :
519534 self ._performance_stats ["connection_errors" ] += 1
@@ -530,16 +545,14 @@ def commit(self) -> None:
530545 if self ._transaction is not None :
531546 await_only (self ._transaction .commit ())
532547 else :
533- await_only (self ._connection .commit ())
548+ await_only (self ._connection .commit ()) # type: ignore[attr-defined]
534549 self ._performance_stats ["transactions_committed" ] += 1
535550 except Exception as e :
536551 self ._performance_stats ["connection_errors" ] += 1
537552 self ._connection_valid = False
538553 # On commit failure, try to rollback
539- try :
554+ with contextlib . suppress ( Exception ) :
540555 self .rollback ()
541- except Exception :
542- pass
543556 raise e
544557 finally :
545558 self ._transaction = None
@@ -549,7 +562,7 @@ def is_valid(self) -> bool:
549562 """Check if connection is valid"""
550563 return self ._connection_valid and self ._connection is not None
551564
552- def ping (self ) -> bool :
565+ def ping (self , reconnect : t . Any = None ) -> t . Any :
553566 """Ping the connection to check if it's alive"""
554567 import time
555568
0 commit comments