Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 60fd96a

Browse files
Add 'database.force_rollback()' context manager. (#189)
* Add 'database.force_rollback()' context manager * Add .force_rollback test case * Typing
1 parent df10fe8 commit 60fd96a

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

databases/core.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
import functools
34
import logging
45
import sys
@@ -186,6 +187,15 @@ def connection(self) -> "Connection":
186187
def transaction(self, *, force_rollback: bool = False) -> "Transaction":
187188
return self.connection().transaction(force_rollback=force_rollback)
188189

190+
@contextlib.contextmanager
191+
def force_rollback(self) -> typing.Iterator[None]:
192+
initial = self._force_rollback
193+
self._force_rollback = True
194+
try:
195+
yield
196+
finally:
197+
self._force_rollback = initial
198+
189199

190200
class Connection:
191201
def __init__(self, backend: DatabaseBackend) -> None:

tests/test_databases.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,28 @@ async def test_rollback_isolation(database_url):
337337
assert len(results) == 0
338338

339339

340+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
341+
@async_adapter
342+
async def test_rollback_isolation_with_contextmanager(database_url):
343+
"""
344+
Ensure that `database.force_rollback()` provides strict isolation.
345+
"""
346+
347+
database = Database(database_url)
348+
349+
with database.force_rollback():
350+
async with database:
351+
# Perform some INSERT operations on the database.
352+
query = notes.insert().values(text="example1", completed=True)
353+
await database.execute(query)
354+
355+
async with database:
356+
# Ensure INSERT operations have been rolled back.
357+
query = notes.select()
358+
results = await database.fetch_all(query=query)
359+
assert len(results) == 0
360+
361+
340362
@pytest.mark.parametrize("database_url", DATABASE_URLS)
341363
@async_adapter
342364
async def test_transaction_commit(database_url):

0 commit comments

Comments
 (0)