Skip to content

Commit 8bcd9e7

Browse files
author
Eugene Shershen
committed
refactor test cases and type annotations; update version to 0.1.0a10
1 parent c9d182b commit 8bcd9e7

9 files changed

Lines changed: 180 additions & 165 deletions

File tree

psqlpy_sqlalchemy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22

33
PsqlpyDialect = PSQLPyAsyncDialect
44

5-
__version__ = "0.1.0a9"
5+
__version__ = "0.1.0a10"
66
__all__ = ["PsqlpyDialect", "PSQLPyAsyncDialect"]

psqlpy_sqlalchemy/connection.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import typing as t
23
from collections import deque
34
from typing import Any, Optional, Tuple, Union
@@ -39,7 +40,9 @@ class AsyncAdapt_psqlpy_cursor(AsyncAdapt_dbapi_cursor):
3940
_adapt_connection: "AsyncAdapt_psqlpy_connection"
4041
_connection: psqlpy.Connection
4142

42-
def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection) -> None:
43+
def __init__(
44+
self, adapt_connection: "AsyncAdapt_psqlpy_connection"
45+
) -> None:
4346
self._adapt_connection = adapt_connection
4447
self._connection = adapt_connection._connection
4548
self._rows: deque[t.Any] = deque()
@@ -177,10 +180,9 @@ def process_value(value: Any) -> Any:
177180
return {
178181
key: process_value(value) for key, value in parameters.items()
179182
}
180-
elif isinstance(parameters, (list, tuple)):
183+
if isinstance(parameters, (list, tuple)):
181184
return [process_value(value) for value in parameters]
182-
else:
183-
return process_value(parameters)
185+
return process_value(parameters)
184186

185187
def _convert_named_params_with_casting(
186188
self,
@@ -330,9 +332,21 @@ async def _executemany(
330332
self._process_parameters(params) for params in seq_of_parameters
331333
]
332334

335+
# Convert to the expected type for execute_many
336+
converted_seq: t.List[t.List[t.Any]] = []
337+
for params in processed_seq:
338+
if params is None:
339+
converted_seq.append([])
340+
elif isinstance(params, dict):
341+
converted_seq.append(list(params.values()))
342+
elif isinstance(params, (list, tuple)):
343+
converted_seq.append(list(params))
344+
else:
345+
converted_seq.append([params])
346+
333347
return await self._connection.execute_many(
334348
operation,
335-
processed_seq,
349+
converted_seq,
336350
True,
337351
)
338352

@@ -360,7 +374,7 @@ class AsyncAdapt_psqlpy_ss_cursor(
360374
):
361375
"""Enhanced server-side cursor with better async iteration support"""
362376

363-
_cursor: psqlpy.Cursor
377+
_cursor: t.Optional[psqlpy.Cursor]
364378

365379
def __init__(
366380
self, adapt_connection: "AsyncAdapt_psqlpy_connection"
@@ -377,7 +391,7 @@ def _convert_result(
377391
) -> Tuple[Tuple[Any, ...], ...]:
378392
"""Enhanced result conversion with better error handling"""
379393
if result is None:
380-
return tuple()
394+
return ()
381395

382396
try:
383397
return tuple(
@@ -386,7 +400,7 @@ def _convert_result(
386400
)
387401
except Exception:
388402
# Return empty tuple on conversion error
389-
return tuple()
403+
return ()
390404

391405
def close(self) -> None:
392406
"""Enhanced close with proper state management"""
@@ -456,6 +470,7 @@ class AsyncAdapt_psqlpy_connection(AsyncAdapt_dbapi_connection):
456470
_ss_cursor_cls = AsyncAdapt_psqlpy_ss_cursor
457471

458472
_connection: psqlpy.Connection
473+
_transaction: t.Optional[psqlpy.Transaction]
459474

460475
__slots__ = (
461476
"_invalidate_schema_cache_asof",
@@ -513,7 +528,7 @@ def rollback(self) -> None:
513528
if self._transaction is not None:
514529
await_only(self._transaction.rollback())
515530
else:
516-
await_only(self._connection.rollback())
531+
await_only(self._connection.rollback()) # type: ignore[attr-defined]
517532
self._performance_stats["transactions_rolled_back"] += 1
518533
except Exception:
519534
self._performance_stats["connection_errors"] += 1
@@ -530,16 +545,14 @@ def commit(self) -> None:
530545
if self._transaction is not None:
531546
await_only(self._transaction.commit())
532547
else:
533-
await_only(self._connection.commit())
548+
await_only(self._connection.commit()) # type: ignore[attr-defined]
534549
self._performance_stats["transactions_committed"] += 1
535550
except Exception as e:
536551
self._performance_stats["connection_errors"] += 1
537552
self._connection_valid = False
538553
# On commit failure, try to rollback
539-
try:
554+
with contextlib.suppress(Exception):
540555
self.rollback()
541-
except Exception:
542-
pass
543556
raise e
544557
finally:
545558
self._transaction = None
@@ -549,7 +562,7 @@ def is_valid(self) -> bool:
549562
"""Check if connection is valid"""
550563
return self._connection_valid and self._connection is not None
551564

552-
def ping(self) -> bool:
565+
def ping(self, reconnect: t.Any = None) -> t.Any:
553566
"""Ping the connection to check if it's alive"""
554567
import time
555568

psqlpy_sqlalchemy/dbapi.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ def connect(
3838

3939
# Handle server_settings parameter that SQLAlchemy might pass
4040
server_settings = kw.pop("server_settings", None)
41-
if server_settings:
42-
# Map server_settings to individual psqlpy parameters
43-
if "application_name" in server_settings:
44-
kw["application_name"] = server_settings["application_name"]
41+
if server_settings and "application_name" in server_settings:
42+
kw["application_name"] = server_settings["application_name"]
4543
# Add other server_settings mappings as needed
4644

4745
# Filter out any other unsupported parameters that SQLAlchemy might pass

psqlpy_sqlalchemy/dialect.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sqlalchemy import URL, util
88
from sqlalchemy.dialects.postgresql.base import INTERVAL, UUID, PGDialect
99
from sqlalchemy.dialects.postgresql.json import JSONPathType
10+
from sqlalchemy.engine.interfaces import DBAPIConnection
1011
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
1112
from sqlalchemy.sql import operators, sqltypes
1213
from sqlalchemy.sql.functions import GenericFunction
@@ -41,80 +42,80 @@ def __init__(
4142

4243

4344
# JSONB aggregation functions
44-
class jsonb_agg(GenericFunction):
45+
class jsonb_agg(GenericFunction[t.Any]):
4546
"""JSONB aggregation function"""
4647

47-
type = sqltypes.JSON
48+
type_ = sqltypes.JSON
4849
name = "jsonb_agg"
4950

5051

51-
class jsonb_object_agg(GenericFunction):
52+
class jsonb_object_agg(GenericFunction[t.Any]):
5253
"""JSONB object aggregation function"""
5354

54-
type = sqltypes.JSON
55+
type_ = sqltypes.JSON
5556
name = "jsonb_object_agg"
5657

5758

58-
class jsonb_build_array(GenericFunction):
59+
class jsonb_build_array(GenericFunction[t.Any]):
5960
"""JSONB build array function"""
6061

61-
type = sqltypes.JSON
62+
type_ = sqltypes.JSON
6263
name = "jsonb_build_array"
6364

6465

65-
class jsonb_build_object(GenericFunction):
66+
class jsonb_build_object(GenericFunction[t.Any]):
6667
"""JSONB build object function"""
6768

68-
type = sqltypes.JSON
69+
type_ = sqltypes.JSON
6970
name = "jsonb_build_object"
7071

7172

72-
class jsonb_extract_path(GenericFunction):
73+
class jsonb_extract_path(GenericFunction[t.Any]):
7374
"""JSONB extract path function"""
7475

75-
type = sqltypes.JSON
76+
type_ = sqltypes.JSON
7677
name = "jsonb_extract_path"
7778

7879

79-
class jsonb_extract_path_text(GenericFunction):
80+
class jsonb_extract_path_text(GenericFunction[t.Any]):
8081
"""JSONB extract path as text function"""
8182

82-
type = sqltypes.Text
83+
type_ = sqltypes.Text
8384
name = "jsonb_extract_path_text"
8485

8586

86-
class jsonb_path_exists(GenericFunction):
87+
class jsonb_path_exists(GenericFunction[t.Any]):
8788
"""JSONB path exists function"""
8889

89-
type = sqltypes.Boolean
90+
type_ = sqltypes.Boolean
9091
name = "jsonb_path_exists"
9192

9293

93-
class jsonb_path_match(GenericFunction):
94+
class jsonb_path_match(GenericFunction[t.Any]):
9495
"""JSONB path match function"""
9596

96-
type = sqltypes.Boolean
97+
type_ = sqltypes.Boolean
9798
name = "jsonb_path_match"
9899

99100

100-
class jsonb_path_query(GenericFunction):
101+
class jsonb_path_query(GenericFunction[t.Any]):
101102
"""JSONB path query function"""
102103

103-
type = sqltypes.JSON
104+
type_ = sqltypes.JSON
104105
name = "jsonb_path_query"
105106

106107

107-
class jsonb_path_query_array(GenericFunction):
108+
class jsonb_path_query_array(GenericFunction[t.Any]):
108109
"""JSONB path query array function"""
109110

110-
type = sqltypes.JSON
111+
type_ = sqltypes.JSON
111112
name = "jsonb_path_query_array"
112113

113114

114-
class jsonb_path_query_first(GenericFunction):
115+
class jsonb_path_query_first(GenericFunction[t.Any]):
115116
"""JSONB path query first function"""
116117

117-
type = sqltypes.JSON
118+
type_ = sqltypes.JSON
118119
name = "jsonb_path_query_first"
119120

120121

@@ -143,10 +144,10 @@ class _PGJSONB(sqltypes.JSON):
143144
__visit_name__ = "JSONB"
144145
render_bind_cast = True
145146

146-
class Comparator(sqltypes.JSON.Comparator):
147+
class Comparator(sqltypes.JSON.Comparator[t.Any]):
147148
"""Enhanced comparator with JSONB-specific operators"""
148149

149-
def contains(self, other: t.Any) -> t.Any:
150+
def contains(self, other: t.Any, **kw: t.Any) -> t.Any:
150151
"""JSONB containment operator @>"""
151152
return self.operate(operators.custom_op("@>"), other)
152153

@@ -225,7 +226,7 @@ class _PGNullType(sqltypes.NullType):
225226
render_bind_cast = True
226227

227228

228-
class _PGUUID(UUID):
229+
class _PGUUID(UUID[t.Any]):
229230
"""PostgreSQL UUID type with proper parameter binding for psqlpy."""
230231

231232
def bind_processor(
@@ -339,10 +340,13 @@ def create_connect_args(
339340

340341
def set_isolation_level(
341342
self,
342-
dbapi_connection: AsyncAdapt_psqlpy_connection,
343+
dbapi_connection: DBAPIConnection,
343344
level: t.Any,
344345
) -> None:
345-
dbapi_connection.set_isolation_level(self._isolation_lookup[level])
346+
psqlpy_connection = t.cast(
347+
AsyncAdapt_psqlpy_connection, dbapi_connection
348+
)
349+
psqlpy_connection.set_isolation_level(self._isolation_lookup[level])
346350

347351
def set_readonly(self, connection: t.Any, value: t.Any) -> None:
348352
if value is True:

pyproject.toml

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "psqlpy-sqlalchemy"
7-
version = "0.1.0a9"
7+
version = "0.1.0a10"
88
description = "SQLAlchemy dialect for psqlpy PostgreSQL driver"
99
readme = "README.md"
1010
license = {text = "MIT"}
@@ -68,14 +68,23 @@ target-version = "py38"
6868

6969
[tool.ruff.lint]
7070
select = [
71-
"E", # pycodestyle errors
72-
"W", # pycodestyle warnings
73-
"F", # pyflakes
74-
"I", # isort
75-
"UP", # pyupgrade
71+
"UP", # pyupgrade
72+
"E", # pycodestyle errors
73+
"W", # pycodestyle warnings
74+
"F", # pyflakes
75+
"I", # isort
76+
"C", # flake8-comprehensions
77+
"B", # flake8-bugbear
78+
"PTH", # flake8-use-pathlib
79+
"ASYNC", # flake8-async
80+
"SIM", # flake8-simplify
81+
"RET", # flake8-return
7682
]
7783
ignore = [
78-
"E501", # line too long (handled by formatter)
84+
"E501", # line too long
85+
"C901", # too complex
86+
"B008", # do not perform function calls in argument defaults
87+
"B904", # Within an `except` clause, raise exceptions with `raise ... from err`
7988
]
8089

8190
[tool.ruff.format]
@@ -108,6 +117,7 @@ ignore_missing_imports = true
108117

109118
[[tool.mypy.overrides]]
110119
module = "psqlpy_sqlalchemy.*"
120+
warn_unused_ignores = false
111121

112122
[[tool.mypy.overrides]]
113123
module = "tests.*"

0 commit comments

Comments
 (0)