Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 15ecd04

Browse files
Failing test case for database.transaction() as a decorator (#192)
* Failing test case for database.transaction() as a decorator * Rollback #158 * Tweak .force_rollback * Revert "Tweak .force_rollback" This reverts commit 2835db7. * Revert "Rollback #158" This reverts commit c18b19c. * Fix transction decorator * Linting * Version 0.3.1
1 parent 60fd96a commit 15ecd04

File tree

3 files changed

+23
-13
lines changed

3 files changed

+23
-13
lines changed

databases/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from databases.core import Database, DatabaseURL
22

3-
__version__ = "0.3.0"
3+
__version__ = "0.3.1"
44
__all__ = ["Database", "DatabaseURL"]

databases/core.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def connection(self) -> "Connection":
185185
return connection
186186

187187
def transaction(self, *, force_rollback: bool = False) -> "Transaction":
188-
return self.connection().transaction(force_rollback=force_rollback)
188+
return Transaction(self.connection, force_rollback=force_rollback)
189189

190190
@contextlib.contextmanager
191191
def force_rollback(self) -> typing.Iterator[None]:
@@ -275,7 +275,10 @@ async def iterate(
275275
yield record
276276

277277
def transaction(self, *, force_rollback: bool = False) -> "Transaction":
278-
return Transaction(self, force_rollback)
278+
def connection_callable() -> Connection:
279+
return self
280+
281+
return Transaction(connection_callable, force_rollback)
279282

280283
@property
281284
def raw_connection(self) -> typing.Any:
@@ -296,10 +299,13 @@ def _build_query(
296299

297300

298301
class Transaction:
299-
def __init__(self, connection: Connection, force_rollback: bool) -> None:
300-
self._connection = connection
302+
def __init__(
303+
self,
304+
connection_callable: typing.Callable[[], Connection],
305+
force_rollback: bool,
306+
) -> None:
307+
self._connection_callable = connection_callable
301308
self._force_rollback = force_rollback
302-
self._transaction = connection._connection.transaction()
303309

304310
async def __aenter__(self) -> "Transaction":
305311
"""
@@ -341,6 +347,9 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
341347
return wrapper
342348

343349
async def start(self) -> "Transaction":
350+
self._connection = self._connection_callable()
351+
self._transaction = self._connection._connection.transaction()
352+
344353
async with self._connection._transaction_lock:
345354
is_root = not self._connection._transaction_stack
346355
await self._connection.__aenter__()

tests/test_databases.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -451,15 +451,16 @@ async def test_transaction_decorator(database_url):
451451
"""
452452
Ensure that @database.transaction() is supported.
453453
"""
454-
async with Database(database_url, force_rollback=True) as database:
454+
database = Database(database_url, force_rollback=True)
455455

456-
@database.transaction()
457-
async def insert_data(raise_exception):
458-
query = notes.insert().values(text="example", completed=True)
459-
await database.execute(query)
460-
if raise_exception:
461-
raise RuntimeError()
456+
@database.transaction()
457+
async def insert_data(raise_exception):
458+
query = notes.insert().values(text="example", completed=True)
459+
await database.execute(query)
460+
if raise_exception:
461+
raise RuntimeError()
462462

463+
async with database:
463464
with pytest.raises(RuntimeError):
464465
await insert_data(raise_exception=True)
465466

0 commit comments

Comments
 (0)