55import typing
66from contextvars import ContextVar
77from types import TracebackType
8+ from typing import Optional
89from urllib .parse import SplitResult , parse_qsl , unquote , urlsplit
910
1011from sqlalchemy import text
@@ -63,8 +64,13 @@ def __init__(
6364 assert issubclass (backend_cls , DatabaseBackend )
6465 self ._backend = backend_cls (self .url , ** self .options )
6566
66- # Connections are stored as task-local state.
67- self ._connection_context : ContextVar = ContextVar ("connection_context" )
67+ # Connections are stored as task-local state, and cannot be garbage collected,
68+ # since the immutable global Context stores a strong reference to each ContextVar
69+ # that is created. We need these local ContextVars since two Database objects
70+ # could run in the same asyncio.Task with connections to different databases.
71+ self ._connection_contextvar : ContextVar [Optional ["Connection" ]] = ContextVar (
72+ f"databases:Database:{ id (self )} "
73+ )
6874
6975 # When `force_rollback=True` is used, we use a single global
7076 # connection, within a transaction that always rolls back.
@@ -113,7 +119,7 @@ async def disconnect(self) -> None:
113119 self ._global_transaction = None
114120 self ._global_connection = None
115121 else :
116- self ._connection_context = ContextVar ( "connection_context" )
122+ self ._connection_contextvar . set ( None )
117123
118124 await self ._backend .disconnect ()
119125 logger .info (
@@ -187,12 +193,12 @@ def connection(self) -> "Connection":
187193 if self ._global_connection is not None :
188194 return self ._global_connection
189195
190- try :
191- return self ._connection_context .get ()
192- except LookupError :
196+ connection = self ._connection_contextvar .get (default = None )
197+ if connection is None :
193198 connection = Connection (self ._backend )
194- self ._connection_context .set (connection )
195- return connection
199+ self ._connection_contextvar .set (connection )
200+
201+ return connection
196202
197203 def transaction (
198204 self , * , force_rollback : bool = False , ** kwargs : typing .Any
@@ -344,9 +350,15 @@ def __init__(
344350 self ._connection_callable = connection_callable
345351 self ._force_rollback = force_rollback
346352 self ._extra_options = kwargs
347- self ._transaction_context : ContextVar [TransactionBackend | None ] = ContextVar (
348- "transaction_context"
349- )
353+
354+ # This ContextVar can never be garbage collected - similar to the ContextVar
355+ # at Database._connection_contextvar - since the current Context has a strong
356+ # reference to every ContextVar that is created. We need local ContextVars since
357+ # there may be multiple (even nested) transactions in a single asyncio.Task,
358+ # which each need their own unique TransactionBackend object.
359+ self ._transaction_contextvar : ContextVar [
360+ Optional [TransactionBackend ]
361+ ] = ContextVar (f"databases:Transaction:{ id (self )} " )
350362
351363 async def __aenter__ (self ) -> "Transaction" :
352364 """
@@ -390,7 +402,11 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
390402 async def start (self ) -> "Transaction" :
391403 connection = self ._connection_callable ()
392404 transaction = connection ._connection .transaction ()
393- self ._transaction_context .set (transaction )
405+
406+ # Cannot store returned reset token anywhere, for the same reason
407+ # we need a ContextVar in the first place - `self` is not
408+ # a safe object on which to store references for concurrent code.
409+ self ._transaction_contextvar .set (transaction )
394410
395411 async with connection ._transaction_lock :
396412 is_root = not connection ._transaction_stack
@@ -401,25 +417,27 @@ async def start(self) -> "Transaction":
401417
402418 async def commit (self ) -> None :
403419 connection = self ._connection_callable ()
404- transaction = self ._transaction_context .get ()
420+ transaction = self ._transaction_contextvar .get (default = None )
405421 assert transaction is not None , "Transaction not found in current task"
406422 async with connection ._transaction_lock :
407423 assert connection ._transaction_stack [- 1 ] is self
408424 connection ._transaction_stack .pop ()
409425 await transaction .commit ()
410426 await connection .__aexit__ ()
411- self ._transaction_context .set (None )
427+ # Have no reset token, set to None instead
428+ self ._transaction_contextvar .set (None )
412429
413430 async def rollback (self ) -> None :
414431 connection = self ._connection_callable ()
415- transaction = self ._transaction_context .get ()
432+ transaction = self ._transaction_contextvar .get (default = None )
416433 assert transaction is not None , "Transaction not found in current task"
417434 async with connection ._transaction_lock :
418435 assert connection ._transaction_stack [- 1 ] is self
419436 connection ._transaction_stack .pop ()
420437 await transaction .rollback ()
421438 await connection .__aexit__ ()
422- self ._transaction_context .set (None )
439+ # Have no reset token, set to None instead
440+ self ._transaction_contextvar .set (None )
423441
424442
425443class _EmptyNetloc (str ):
0 commit comments