Skip to content

Commit 33c31ac

Browse files
febus982claude
andauthored
Other code improvements (#85)
* Remove duplicate license header in abstract.py Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Remove unnecessary partial wrapper in order_by registry The partial() wrapper was not needed since asc and desc can be referenced directly as callables. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Remove deprecated future=True engine option The future=True option is deprecated in SQLAlchemy 2.0 and is now the default behavior. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Remove redundant or [] fallback in cursor_paginated_find The list comprehension already produces a list, making the or [] fallback unnecessary. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Add exception handling to SessionHandler.__del__ The scoped_session.remove() call could fail if the database connection is already closed during garbage collection. Wrapping in try/except with debug logging prevents errors during interpreter shutdown. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Use any() for more efficient model validation Replaces list comprehension with any() generator expression to short-circuit on first invalid model instead of iterating all items. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Extract shared primary key utility function Consolidates duplicated PK inspection logic from BaseRepository._model_pk() and result_presenters._pk_from_result_object() into a single get_model_pk_name() function in common.py. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Add test for SessionHandler.__del__ exception handling Verifies that exceptions from scoped_session.remove() are caught and logged during cleanup, preventing errors during garbage collection. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Fix type annotation for order_by function registry Use Callable instead of type since asc/desc are functions, not types. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c2a9ad2 commit 33c31ac

File tree

9 files changed

+52
-40
lines changed

9 files changed

+52
-40
lines changed

sqlalchemy_bind_manager/_bind_manager.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def __init_bind(self, name: str, config: SQLAlchemyConfig):
115115

116116
engine_options: dict = config.engine_options or {}
117117
engine_options.setdefault("echo", False)
118-
engine_options.setdefault("future", True)
119118

120119
session_options: dict = config.session_options or {}
121120
session_options.setdefault("expire_on_commit", False)

sqlalchemy_bind_manager/_repository/abstract.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,6 @@
1818
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
1919
# DEALINGS IN THE SOFTWARE.
2020

21-
#
22-
# Permission is hereby granted, free of charge, to any person obtaining a
23-
# copy of this software and associated documentation files (the "Software"),
24-
# to deal in the Software without restriction, including without limitation
25-
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
26-
# and/or sell copies of the Software, and to permit persons to whom the
27-
# Software is furnished to do so, subject to the following conditions:
28-
#
29-
#
3021
from typing import (
3122
Any,
3223
Iterable,

sqlalchemy_bind_manager/_repository/async_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ async def cursor_paginated_find(
274274
).scalar() or 0
275275
result_items = [
276276
x for x in (await session.execute(paginated_stmt)).scalars()
277-
] or []
277+
]
278278

279279
return CursorPaginatedResultPresenter.build_result(
280280
result_items=result_items,

sqlalchemy_bind_manager/_repository/base_repository.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
# DEALINGS IN THE SOFTWARE.
2020

2121
from abc import ABC
22-
from functools import partial
2322
from typing import (
2423
Any,
2524
Callable,
@@ -33,7 +32,7 @@
3332
Union,
3433
)
3534

36-
from sqlalchemy import asc, desc, func, inspect, select
35+
from sqlalchemy import asc, desc, func, select
3736
from sqlalchemy.orm import Mapper, aliased, class_mapper, lazyload
3837
from sqlalchemy.orm.exc import UnmappedClassError
3938
from sqlalchemy.sql import Select
@@ -43,6 +42,7 @@
4342
from .common import (
4443
MODEL,
4544
CursorReference,
45+
get_model_pk_name,
4646
)
4747

4848

@@ -131,9 +131,9 @@ def _filter_order_by(
131131
:param order_by: a list of columns, or tuples (column, direction)
132132
:return: The filtered query
133133
"""
134-
_partial_registry: Dict[Literal["asc", "desc"], Callable] = {
135-
"desc": partial(desc),
136-
"asc": partial(asc),
134+
_order_funcs: Dict[Literal["asc", "desc"], Callable] = {
135+
"desc": desc,
136+
"asc": asc,
137137
}
138138

139139
for value in order_by:
@@ -143,7 +143,7 @@ def _filter_order_by(
143143
else:
144144
self._validate_mapped_property(value[0])
145145
stmt = stmt.order_by(
146-
_partial_registry[value[1]](getattr(self._model, value[0]))
146+
_order_funcs[value[1]](getattr(self._model, value[0]))
147147
)
148148

149149
return stmt
@@ -344,14 +344,10 @@ def _model_pk(self) -> str:
344344
345345
:return:
346346
"""
347-
primary_keys = inspect(self._model).primary_key # type: ignore
348-
if len(primary_keys) > 1:
349-
raise NotImplementedError("Composite primary keys are not supported.")
350-
351-
return primary_keys[0].name
347+
return get_model_pk_name(self._model)
352348

353349
def _fail_if_invalid_models(self, objects: Iterable[MODEL]) -> None:
354-
if [x for x in objects if not isinstance(x, self._model)]:
350+
if any(not isinstance(x, self._model) for x in objects):
355351
raise InvalidModelError(
356352
"Cannot handle models not belonging to this repository"
357353
)

sqlalchemy_bind_manager/_repository/common.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,29 @@
1818
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
1919
# DEALINGS IN THE SOFTWARE.
2020

21-
from typing import Generic, List, TypeVar, Union
21+
from typing import Generic, List, Type, TypeVar, Union
2222
from uuid import UUID
2323

2424
from pydantic import BaseModel, StrictInt, StrictStr
25+
from sqlalchemy import inspect
2526

2627
MODEL = TypeVar("MODEL")
2728
PRIMARY_KEY = Union[str, int, tuple, dict, UUID]
2829

2930

31+
def get_model_pk_name(model_class: Type) -> str:
32+
"""Retrieves the primary key column name from a SQLAlchemy model class.
33+
34+
:param model_class: A SQLAlchemy model class
35+
:return: The name of the primary key column
36+
:raises NotImplementedError: If the model has composite primary keys
37+
"""
38+
primary_keys = inspect(model_class).primary_key # type: ignore
39+
if len(primary_keys) > 1:
40+
raise NotImplementedError("Composite primary keys are not supported.")
41+
return primary_keys[0].name
42+
43+
3044
class PageInfo(BaseModel):
3145
"""
3246
Paginated query metadata.

sqlalchemy_bind_manager/_repository/result_presenters.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@
2121
from math import ceil
2222
from typing import List, Union
2323

24-
from sqlalchemy import inspect
25-
2624
from .common import (
2725
MODEL,
2826
CursorPageInfo,
2927
CursorPaginatedResult,
3028
CursorReference,
3129
PageInfo,
3230
PaginatedResult,
31+
get_model_pk_name,
3332
)
3433

3534

@@ -93,7 +92,7 @@ def _build_no_cursor_result(
9392
has_next_page = len(result_items) > items_per_page
9493
if has_next_page:
9594
result_items = result_items[0:items_per_page]
96-
reference_column = _pk_from_result_object(result_items[0])
95+
reference_column = get_model_pk_name(type(result_items[0]))
9796

9897
return CursorPaginatedResult(
9998
items=result_items,
@@ -237,11 +236,3 @@ def build_result(
237236
has_previous_page=has_previous_page,
238237
),
239238
)
240-
241-
242-
def _pk_from_result_object(model) -> str:
243-
primary_keys = inspect(type(model)).primary_key # type: ignore
244-
if len(primary_keys) > 1:
245-
raise NotImplementedError("Composite primary keys are not supported.")
246-
247-
return primary_keys[0].name

sqlalchemy_bind_manager/_session_handler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# DEALINGS IN THE SOFTWARE.
2020

2121
import asyncio
22+
import logging
2223
from contextlib import asynccontextmanager, contextmanager
2324
from typing import AsyncIterator, Iterator
2425

@@ -34,6 +35,8 @@
3435
)
3536
from sqlalchemy_bind_manager.exceptions import UnsupportedBindError
3637

38+
logger = logging.getLogger(__name__)
39+
3740

3841
class SessionHandler:
3942
scoped_session: scoped_session
@@ -45,8 +48,11 @@ def __init__(self, bind: SQLAlchemyBind):
4548
self.scoped_session = scoped_session(bind.session_class)
4649

4750
def __del__(self):
48-
if getattr(self, "scoped_session", None):
49-
self.scoped_session.remove()
51+
try:
52+
if getattr(self, "scoped_session", None):
53+
self.scoped_session.remove()
54+
except Exception:
55+
logger.debug("Failed to remove scoped session", exc_info=True)
5056

5157
@contextmanager
5258
def get_session(self, read_only: bool = False) -> Iterator[Session]:

tests/repository/result_presenters/test_composite_pk.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
import pytest
44

5-
from sqlalchemy_bind_manager._repository.result_presenters import _pk_from_result_object
5+
from sqlalchemy_bind_manager._repository.common import get_model_pk_name
66

77

88
def test_exception_raised_if_multiple_primary_keys():
99
with (
1010
patch(
11-
"sqlalchemy_bind_manager._repository.result_presenters.inspect",
11+
"sqlalchemy_bind_manager._repository.common.inspect",
1212
return_value=Mock(primary_key=["1", "2"]),
1313
),
1414
pytest.raises(NotImplementedError),
1515
):
16-
_pk_from_result_object("irrelevant")
16+
get_model_pk_name(str)

tests/session_handler/test_session_lifecycle.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@ def test_sync_session_is_removed_on_cleanup(sa_manager):
3232
mocked_remove.assert_called_once()
3333

3434

35+
def test_sync_session_cleanup_handles_exception(sa_manager):
36+
"""Test that __del__ gracefully handles exceptions from scoped_session.remove()."""
37+
sh = SessionHandler(sa_manager.get_bind("sync"))
38+
39+
with patch.object(
40+
sh.scoped_session,
41+
"remove",
42+
side_effect=Exception("Connection already closed"),
43+
) as mocked_remove:
44+
# This should not raise - the exception should be caught and logged
45+
sh.__del__()
46+
47+
mocked_remove.assert_called_once()
48+
49+
3550
@pytest.mark.parametrize("read_only_flag", [True, False])
3651
async def test_commit_is_called_only_if_not_read_only(
3752
read_only_flag,

0 commit comments

Comments
 (0)