55import typing
66from contextvars import ContextVar
77from types import TracebackType
8- from typing import Optional
8+ from typing import Dict , Optional
99from urllib .parse import SplitResult , parse_qsl , unquote , urlsplit
1010
1111from sqlalchemy import text
3535
3636logger = logging .getLogger ("databases" )
3737
38+ # Connections are stored as task-local state, but care must be taken to ensure
39+ # that two database instances in the same task overwrite each other's connections.
40+ # For this reason, key comprises the database instance and the current task.
41+ _connection_contextmap : ContextVar [
42+ Dict [tuple ["Database" , asyncio .Task ], "Connection" ]
43+ ] = ContextVar ("databases:Connection" )
44+
45+
46+ def _get_connection_contextmap () -> Dict [tuple ["Database" , asyncio .Task ], "Connection" ]:
47+ connections = _connection_contextmap .get (None )
48+ if connections is None :
49+ connections = {}
50+ _connection_contextmap .set (connections )
51+ return connections
52+
3853
3954class Database :
4055 SUPPORTED_BACKENDS = {
@@ -64,14 +79,6 @@ def __init__(
6479 assert issubclass (backend_cls , DatabaseBackend )
6580 self ._backend = backend_cls (self .url , ** self .options )
6681
67- # Connections are stored as task-local state, and cannot be garbage collected,
68- # since the immutable global Context stores a strong reference to each ContextVar
69- # that is created. We need these local ContextVars since two Database objects
70- # could run in the same asyncio.Task with connections to different databases.
71- self ._connection_contextvar : ContextVar [Optional ["Connection" ]] = ContextVar (
72- f"databases:Database:{ id (self )} "
73- )
74-
7582 # When `force_rollback=True` is used, we use a single global
7683 # connection, within a transaction that always rolls back.
7784 self ._global_connection : typing .Optional [Connection ] = None
@@ -119,7 +126,10 @@ async def disconnect(self) -> None:
119126 self ._global_transaction = None
120127 self ._global_connection = None
121128 else :
122- self ._connection_contextvar .set (None )
129+ task = asyncio .current_task ()
130+ connections = _get_connection_contextmap ()
131+ if (self , task ) in connections :
132+ del connections [self , task ]
123133
124134 await self ._backend .disconnect ()
125135 logger .info (
@@ -193,12 +203,12 @@ def connection(self) -> "Connection":
193203 if self ._global_connection is not None :
194204 return self ._global_connection
195205
196- connection = self . _connection_contextvar . get ( None )
197- if connection is None :
198- connection = Connection ( self . _backend )
199- self . _connection_contextvar . set ( connection )
206+ task = asyncio . current_task ( )
207+ connections = _get_connection_contextmap ()
208+ if ( self , task ) not in connections :
209+ connections [ self , task ] = Connection ( self . _backend )
200210
201- return connection
211+ return connections [ self , task ]
202212
203213 def transaction (
204214 self , * , force_rollback : bool = False , ** kwargs : typing .Any
@@ -339,6 +349,19 @@ def _build_query(
339349
340350_CallableType = typing .TypeVar ("_CallableType" , bound = typing .Callable )
341351
352+ _transaction_contextmap : ContextVar [
353+ Dict ["Transaction" , TransactionBackend ]
354+ ] = ContextVar ("databases:Transactions" )
355+
356+
357+ def _get_transaction_contextmap () -> Dict ["Transaction" , TransactionBackend ]:
358+ transactions = _transaction_contextmap .get (None )
359+ if transactions is None :
360+ transactions = {}
361+ _transaction_contextmap .set (transactions )
362+
363+ return transactions
364+
342365
343366class Transaction :
344367 def __init__ (
@@ -351,15 +374,6 @@ def __init__(
351374 self ._force_rollback = force_rollback
352375 self ._extra_options = kwargs
353376
354- # This ContextVar can never be garbage collected - similar to the ContextVar
355- # at Database._connection_contextvar - since the current Context has a strong
356- # reference to every ContextVar that is created. We need local ContextVars since
357- # there may be multiple (even nested) transactions in a single asyncio.Task,
358- # which each need their own unique TransactionBackend object.
359- self ._transaction_contextvar : ContextVar [
360- Optional [TransactionBackend ]
361- ] = ContextVar (f"databases:Transaction:{ id (self )} " )
362-
363377 async def __aenter__ (self ) -> "Transaction" :
364378 """
365379 Called when entering `async with database.transaction()`
@@ -402,12 +416,8 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
402416 async def start (self ) -> "Transaction" :
403417 connection = self ._connection_callable ()
404418 transaction = connection ._connection .transaction ()
405-
406- # Cannot store returned reset token anywhere, for the same reason
407- # we need a ContextVar in the first place - `self` is not
408- # a safe object on which to store references for concurrent code.
409- self ._transaction_contextvar .set (transaction )
410-
419+ transactions = _get_transaction_contextmap ()
420+ transactions [self ] = transaction
411421 async with connection ._transaction_lock :
412422 is_root = not connection ._transaction_stack
413423 await connection .__aenter__ ()
@@ -417,27 +427,27 @@ async def start(self) -> "Transaction":
417427
418428 async def commit (self ) -> None :
419429 connection = self ._connection_callable ()
420- transaction = self ._transaction_contextvar .get (None )
430+ transactions = _get_transaction_contextmap ()
431+ transaction = transactions .get (self , None )
421432 assert transaction is not None , "Transaction not found in current task"
422433 async with connection ._transaction_lock :
423434 assert connection ._transaction_stack [- 1 ] is self
424435 connection ._transaction_stack .pop ()
425436 await transaction .commit ()
426437 await connection .__aexit__ ()
427- # Have no reset token, set to None instead
428- self ._transaction_contextvar .set (None )
438+ del transactions [self ]
429439
430440 async def rollback (self ) -> None :
431441 connection = self ._connection_callable ()
432- transaction = self ._transaction_contextvar .get (None )
442+ transactions = _get_transaction_contextmap ()
443+ transaction = transactions .get (self , None )
433444 assert transaction is not None , "Transaction not found in current task"
434445 async with connection ._transaction_lock :
435446 assert connection ._transaction_stack [- 1 ] is self
436447 connection ._transaction_stack .pop ()
437448 await transaction .rollback ()
438449 await connection .__aexit__ ()
439- # Have no reset token, set to None instead
440- self ._transaction_contextvar .set (None )
450+ del transactions [self ]
441451
442452
443453class _EmptyNetloc (str ):
0 commit comments