Skip to content

Commit c9d182b

Browse files
author
Eugene Shershen
committed
add type annotations and update mypy configuration
1 parent 55a8f5c commit c9d182b

4 files changed

Lines changed: 94 additions & 63 deletions

File tree

psqlpy_sqlalchemy/connection.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import typing as t
22
from collections import deque
3-
from typing import Any, Optional, Tuple
3+
from typing import Any, Optional, Tuple, Union
44

55
import psqlpy
66
from psqlpy import row_factories
@@ -26,19 +26,23 @@ def create_server_side_cursor(self) -> "DBAPICursor":
2626

2727
class AsyncAdapt_psqlpy_cursor(AsyncAdapt_dbapi_cursor):
2828
__slots__ = (
29+
"_adapt_connection",
2930
"_arraysize",
31+
"_connection",
32+
"_cursor",
3033
"_description",
3134
"_invalidate_schema_cache_asof",
3235
"_rowcount",
36+
"_rows",
3337
)
3438

3539
_adapt_connection: "AsyncAdapt_psqlpy_connection"
3640
_connection: psqlpy.Connection
3741

38-
def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection):
42+
def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection) -> None:
3943
self._adapt_connection = adapt_connection
4044
self._connection = adapt_connection._connection
41-
self._rows = deque()
45+
self._rows: deque[t.Any] = deque()
4246
self._description: t.Optional[t.List[t.Tuple[t.Any, ...]]] = None
4347
self._arraysize = 1
4448
self._rowcount = -1
@@ -155,7 +159,7 @@ def _process_parameters(
155159

156160
import uuid
157161

158-
def process_value(value):
162+
def process_value(value: Any) -> Any:
159163
"""Process a single parameter value."""
160164
if value is None:
161165
return None
@@ -341,10 +345,12 @@ def execute(
341345
) -> None:
342346
await_only(self._prepare_execute(operation, parameters))
343347

344-
def executemany(self, operation, seq_of_parameters) -> None:
348+
def executemany(
349+
self, operation: t.Any, seq_of_parameters: t.Sequence[t.Any]
350+
) -> None:
345351
return await_only(self._executemany(operation, seq_of_parameters))
346352

347-
def setinputsizes(self, *inputsizes):
353+
def setinputsizes(self, *inputsizes: t.Any) -> None:
348354
raise NotImplementedError
349355

350356

@@ -356,7 +362,9 @@ class AsyncAdapt_psqlpy_ss_cursor(
356362

357363
_cursor: psqlpy.Cursor
358364

359-
def __init__(self, adapt_connection):
365+
def __init__(
366+
self, adapt_connection: "AsyncAdapt_psqlpy_connection"
367+
) -> None:
360368
self._adapt_connection = adapt_connection
361369
self._connection = adapt_connection._connection
362370
self.await_ = adapt_connection.await_
@@ -380,7 +388,7 @@ def _convert_result(
380388
# Return empty tuple on conversion error
381389
return tuple()
382390

383-
def close(self):
391+
def close(self) -> None:
384392
"""Enhanced close with proper state management"""
385393
if self._cursor is not None and not self._closed:
386394
try:
@@ -392,7 +400,7 @@ def close(self):
392400
self._cursor = None
393401
self._closed = True
394402

395-
def fetchone(self):
403+
def fetchone(self) -> Optional[Tuple[Any, ...]]:
396404
"""Fetch one row with enhanced error handling"""
397405
if self._closed or self._cursor is None:
398406
return None
@@ -404,7 +412,7 @@ def fetchone(self):
404412
except Exception:
405413
return None
406414

407-
def fetchmany(self, size=None):
415+
def fetchmany(self, size: Optional[int] = None) -> t.List[Tuple[Any, ...]]:
408416
"""Fetch many rows with enhanced error handling"""
409417
if self._closed or self._cursor is None:
410418
return []
@@ -417,7 +425,7 @@ def fetchmany(self, size=None):
417425
except Exception:
418426
return []
419427

420-
def fetchall(self):
428+
def fetchall(self) -> t.List[Tuple[Any, ...]]:
421429
"""Fetch all rows with enhanced error handling"""
422430
if self._closed or self._cursor is None:
423431
return []
@@ -428,7 +436,7 @@ def fetchall(self):
428436
except Exception:
429437
return []
430438

431-
def __iter__(self):
439+
def __iter__(self) -> t.Iterator[Tuple[Any, ...]]:
432440
if self._closed or self._cursor is None:
433441
return
434442

@@ -464,15 +472,15 @@ class AsyncAdapt_psqlpy_connection(AsyncAdapt_dbapi_connection):
464472
"readonly",
465473
)
466474

467-
def __init__(self, dbapi, connection):
475+
def __init__(self, dbapi: t.Any, connection: psqlpy.Connection) -> None:
468476
super().__init__(dbapi, connection)
469477
self.isolation_level = self._isolation_setting = None
470478
self.readonly = False
471479
self.deferrable = False
472480
self._transaction = None
473481
self._started = False
474482
self._connection_valid = True
475-
self._last_ping_time = 0
483+
self._last_ping_time = 0.0
476484
self._performance_stats = {
477485
"queries_executed": 0,
478486
"transactions_committed": 0,
@@ -496,7 +504,7 @@ async def _start_transaction(self) -> None:
496504
self._started = False
497505
raise
498506

499-
def set_isolation_level(self, level):
507+
def set_isolation_level(self, level: t.Any) -> None:
500508
self.isolation_level = self._isolation_setting = level
501509

502510
def rollback(self) -> None:
@@ -561,7 +569,7 @@ def ping(self) -> bool:
561569
self._performance_stats["connection_errors"] += 1
562570
return False
563571

564-
def get_performance_stats(self) -> dict:
572+
def get_performance_stats(self) -> t.Dict[str, int]:
565573
"""Get connection performance statistics"""
566574
return self._performance_stats.copy()
567575

@@ -574,11 +582,13 @@ def reset_performance_stats(self) -> None:
574582
"connection_errors": 0,
575583
}
576584

577-
def close(self):
585+
def close(self) -> None:
578586
self.rollback()
579587
self._connection.close()
580588

581-
def cursor(self, server_side=False):
589+
def cursor(
590+
self, server_side: bool = False
591+
) -> Union[AsyncAdapt_psqlpy_cursor, AsyncAdapt_psqlpy_ss_cursor]:
582592
if server_side:
583593
return self._ss_cursor_cls(self)
584594
return self._cursor_cls(self)

psqlpy_sqlalchemy/dbapi.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import typing as t
2+
13
from sqlalchemy.util.concurrency import await_only
24

35
from .connection import AsyncAdapt_psqlpy_connection
46

57

68
class PSQLPyAdaptDBAPI:
7-
def __init__(self, psqlpy) -> None:
9+
def __init__(self, psqlpy: t.Any) -> None:
810
self.psqlpy = psqlpy
911
self.paramstyle = "numeric_dollar"
1012

@@ -29,7 +31,9 @@ def __init__(self, psqlpy) -> None:
2931
if k != "connect":
3032
self.__dict__[k] = v
3133

32-
def connect(self, *arg, **kw):
34+
def connect(
35+
self, *arg: t.Any, **kw: t.Any
36+
) -> AsyncAdapt_psqlpy_connection:
3337
creator_fn = kw.pop("async_creator_fn", self.psqlpy.connect)
3438

3539
# Handle server_settings parameter that SQLAlchemy might pass
@@ -88,7 +92,7 @@ class PsqlpyDBAPI:
8892
"numeric_dollar" # PostgreSQL uses $1, $2, etc. style parameters
8993
)
9094

91-
def __init__(self):
95+
def __init__(self) -> None:
9296
# Initialize with psqlpy module
9397
import psqlpy
9498

@@ -108,44 +112,52 @@ def __init__(self):
108112
self.NotSupportedError = _error_class
109113

110114
# Type constructors
111-
def Date(self, year, month, day):
115+
def Date(self, year: int, month: int, day: int) -> t.Any:
112116
"""Construct a date value"""
113117
import datetime
114118

115119
return datetime.date(year, month, day)
116120

117-
def Time(self, hour, minute, second):
121+
def Time(self, hour: int, minute: int, second: int) -> t.Any:
118122
"""Construct a time value"""
119123
import datetime
120124

121125
return datetime.time(hour, minute, second)
122126

123-
def Timestamp(self, year, month, day, hour, minute, second):
127+
def Timestamp(
128+
self,
129+
year: int,
130+
month: int,
131+
day: int,
132+
hour: int,
133+
minute: int,
134+
second: int,
135+
) -> t.Any:
124136
"""Construct a timestamp value"""
125137
import datetime
126138

127139
return datetime.datetime(year, month, day, hour, minute, second)
128140

129-
def DateFromTicks(self, ticks):
141+
def DateFromTicks(self, ticks: float) -> t.Any:
130142
"""Construct a date from ticks"""
131143
import datetime
132144

133145
return datetime.date.fromtimestamp(ticks)
134146

135-
def TimeFromTicks(self, ticks):
147+
def TimeFromTicks(self, ticks: float) -> t.Any:
136148
"""Construct a time from ticks"""
137149
import datetime
138150

139151
dt = datetime.datetime.fromtimestamp(ticks)
140152
return dt.time()
141153

142-
def TimestampFromTicks(self, ticks):
154+
def TimestampFromTicks(self, ticks: float) -> t.Any:
143155
"""Construct a timestamp from ticks"""
144156
import datetime
145157

146158
return datetime.datetime.fromtimestamp(ticks)
147159

148-
def Binary(self, string):
160+
def Binary(self, string: t.Union[str, bytes]) -> bytes:
149161
"""Construct a binary value"""
150162
if isinstance(string, str):
151163
return string.encode("utf-8")
@@ -158,6 +170,8 @@ def Binary(self, string):
158170
DATETIME = object # datetime objects
159171
ROWID = int
160172

161-
def connect(self, *args, **kwargs):
173+
def connect(
174+
self, *args: t.Any, **kwargs: t.Any
175+
) -> AsyncAdapt_psqlpy_connection:
162176
"""Create a connection - delegates to the adapted DBAPI"""
163177
return self._adapt_dbapi.connect(*args, **kwargs)

psqlpy_sqlalchemy/dialect.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,13 @@ class CompatibleNullPool(NullPool):
2424
but are commonly passed by frameworks like FastAPI with fastapi_async_sqlalchemy.
2525
"""
2626

27-
def __init__(self, creator, pool_size=None, max_overflow=None, **kw):
27+
def __init__(
28+
self,
29+
creator: t.Any,
30+
pool_size: t.Optional[int] = None,
31+
max_overflow: t.Optional[int] = None,
32+
**kw: t.Any,
33+
) -> None:
2834
# Filter out pool sizing arguments that NullPool doesn't accept
2935
filtered_kw = {
3036
k: v
@@ -140,43 +146,43 @@ class _PGJSONB(sqltypes.JSON):
140146
class Comparator(sqltypes.JSON.Comparator):
141147
"""Enhanced comparator with JSONB-specific operators"""
142148

143-
def contains(self, other):
149+
def contains(self, other: t.Any) -> t.Any:
144150
"""JSONB containment operator @>"""
145151
return self.operate(operators.custom_op("@>"), other)
146152

147-
def contained_by(self, other):
153+
def contained_by(self, other: t.Any) -> t.Any:
148154
"""JSONB contained by operator <@"""
149155
return self.operate(operators.custom_op("<@"), other)
150156

151-
def has_key(self, key):
157+
def has_key(self, key: t.Any) -> t.Any:
152158
"""JSONB has key operator ?"""
153159
return self.operate(operators.custom_op("?"), key)
154160

155-
def has_any_key(self, keys):
161+
def has_any_key(self, keys: t.Any) -> t.Any:
156162
"""JSONB has any key operator ?|"""
157163
return self.operate(operators.custom_op("?|"), keys)
158164

159-
def has_all_keys(self, keys):
165+
def has_all_keys(self, keys: t.Any) -> t.Any:
160166
"""JSONB has all keys operator ?&"""
161167
return self.operate(operators.custom_op("?&"), keys)
162168

163-
def path_exists(self, path):
169+
def path_exists(self, path: t.Any) -> t.Any:
164170
"""JSONB path exists operator @?"""
165171
return self.operate(operators.custom_op("@?"), path)
166172

167-
def path_match(self, path):
173+
def path_match(self, path: t.Any) -> t.Any:
168174
"""JSONB path match operator @@"""
169175
return self.operate(operators.custom_op("@@"), path)
170176

171-
def concat(self, other):
177+
def concat(self, other: t.Any) -> t.Any:
172178
"""JSONB concatenation operator ||"""
173179
return self.operate(operators.custom_op("||"), other)
174180

175-
def delete_key(self, key):
181+
def delete_key(self, key: t.Any) -> t.Any:
176182
"""JSONB delete key operator -"""
177183
return self.operate(operators.custom_op("-"), key)
178184

179-
def delete_path(self, path):
185+
def delete_path(self, path: t.Any) -> t.Any:
180186
"""JSONB delete path operator #-"""
181187
return self.operate(operators.custom_op("#-"), path)
182188

@@ -222,10 +228,12 @@ class _PGNullType(sqltypes.NullType):
222228
class _PGUUID(UUID):
223229
"""PostgreSQL UUID type with proper parameter binding for psqlpy."""
224230

225-
def bind_processor(self, dialect):
231+
def bind_processor(
232+
self, dialect: t.Any
233+
) -> t.Optional[t.Callable[[t.Any], t.Any]]:
226234
"""Process UUID parameters for psqlpy compatibility."""
227235

228-
def process(value):
236+
def process(value: t.Any) -> t.Optional[bytes]:
229237
if value is None:
230238
return None
231239
if isinstance(value, uuid.UUID):
@@ -332,23 +340,23 @@ def create_connect_args(
332340
def set_isolation_level(
333341
self,
334342
dbapi_connection: AsyncAdapt_psqlpy_connection,
335-
level,
336-
):
343+
level: t.Any,
344+
) -> None:
337345
dbapi_connection.set_isolation_level(self._isolation_lookup[level])
338346

339-
def set_readonly(self, connection, value):
347+
def set_readonly(self, connection: t.Any, value: t.Any) -> None:
340348
if value is True:
341349
connection.readonly = psqlpy.ReadVariant.ReadOnly
342350
else:
343351
connection.readonly = psqlpy.ReadVariant.ReadWrite
344352

345-
def get_readonly(self, connection):
353+
def get_readonly(self, connection: t.Any) -> t.Any:
346354
return connection.readonly
347355

348-
def set_deferrable(self, connection, value):
356+
def set_deferrable(self, connection: t.Any, value: t.Any) -> None:
349357
connection.deferrable = value
350358

351-
def get_deferrable(self, connection):
359+
def get_deferrable(self, connection: t.Any) -> t.Any:
352360
return connection.deferrable
353361

354362

0 commit comments

Comments
 (0)