11from __future__ import annotations
22
3+ import itertools
34from abc import ABC , abstractmethod
45from collections import defaultdict
5- from collections .abc import AsyncIterator , Iterator , MutableMapping
6+ from collections .abc import AsyncIterator , Iterator , Mapping , MutableMapping
67from contextlib import AsyncExitStack , ExitStack , asynccontextmanager , contextmanager
78from contextvars import ContextVar
89from dataclasses import dataclass , field
9- from types import TracebackType
10+ from types import EllipsisType , TracebackType
1011from typing import (
1112 Any ,
1213 AsyncContextManager ,
1516 NoReturn ,
1617 Protocol ,
1718 Self ,
19+ overload ,
1820 runtime_checkable ,
1921)
2022
2628)
2729
2830
29- @dataclass (repr = False , slots = True )
30- class _ScopeState :
31- # Shouldn't be instantiated outside `__SCOPES`.
31+ @runtime_checkable
32+ class ScopeState (Protocol ):
33+ __slots__ = ()
34+
35+ @property
36+ @abstractmethod
37+ def active_scopes (self ) -> Iterator [Scope ]:
38+ raise NotImplementedError
39+
40+ @abstractmethod
41+ def bind (self , scope : Scope ) -> ContextManager [None ]:
42+ raise NotImplementedError
43+
44+ @abstractmethod
45+ def get_scope (self ) -> Scope | None :
46+ raise NotImplementedError
47+
48+
49+ @dataclass (repr = False , frozen = True , slots = True )
50+ class _ContextualScopeState (ScopeState ):
51+ # Shouldn't be instantiated outside `__CONTEXTUAL_SCOPES`.
3252
3353 __context_var : ContextVar [Scope ] = field (
3454 default_factory = lambda : ContextVar (f"scope@{ new_short_key ()} " ),
3555 init = False ,
3656 )
37- __default : Scope | None = field (
38- default = None ,
39- init = False ,
40- )
4157 __references : set [Scope ] = field (
4258 default_factory = set ,
4359 init = False ,
4460 )
4561
4662 @property
4763 def active_scopes (self ) -> Iterator [Scope ]:
48- yield from self .__references
49-
50- if default := self .__default :
51- yield default
64+ return iter (self .__references )
5265
5366 @contextmanager
54- def bind_contextual_scope (self , scope : Scope ) -> Iterator [None ]:
67+ def bind (self , scope : Scope ) -> Iterator [None ]:
5568 self .__references .add (scope )
5669 token = self .__context_var .set (scope )
5770
@@ -61,26 +74,38 @@ def bind_contextual_scope(self, scope: Scope) -> Iterator[None]:
6174 self .__context_var .reset (token )
6275 self .__references .remove (scope )
6376
64- @contextmanager
65- def bind_shared_scope (self , scope : Scope ) -> Iterator [None ]:
66- if next (self .active_scopes , None ):
67- raise ScopeError (
68- "A shared scope can't be defined when one or more contextual scopes "
69- "are defined on the same name."
70- )
77+ def get_scope (self ) -> Scope | None :
78+ return self .__context_var .get (None )
79+
80+
81+ @dataclass (repr = False , slots = True )
82+ class _SharedScopeState (ScopeState ):
83+ __scope : Scope | None = field (default = None )
84+
85+ @property
86+ def active_scopes (self ) -> Iterator [Scope ]:
87+ if scope := self .__scope :
88+ yield scope
7189
72- self .__default = scope
90+ @contextmanager
91+ def bind (self , scope : Scope ) -> Iterator [None ]:
92+ self .__scope = scope
7393
7494 try :
7595 yield
7696 finally :
77- self .__default = None
97+ self .__scope = None
7898
7999 def get_scope (self ) -> Scope | None :
80- return self .__context_var . get ( self . __default )
100+ return self .__scope
81101
82102
83- __SCOPES : Final [defaultdict [str , _ScopeState ]] = defaultdict (_ScopeState )
103+ __CONTEXTUAL_SCOPES : Final [Mapping [str , ScopeState ]] = defaultdict (
104+ _ContextualScopeState ,
105+ )
106+ __SHARED_SCOPES : Final [Mapping [str , ScopeState ]] = defaultdict (
107+ _SharedScopeState ,
108+ )
84109
85110
86111@asynccontextmanager
@@ -98,36 +123,52 @@ def define_scope(name: str, *, shared: bool = False) -> Iterator[None]:
98123
99124
100125def get_active_scopes (name : str ) -> tuple [Scope , ...]:
101- state = __SCOPES .get (name )
126+ active_scopes = (
127+ state .active_scopes
128+ for states in (__CONTEXTUAL_SCOPES , __SHARED_SCOPES )
129+ if (state := states .get (name ))
130+ )
131+ return tuple (itertools .chain .from_iterable (active_scopes ))
132+
102133
103- if state is None :
104- return ()
134+ @ overload
135+ def get_scope ( name : str , default : EllipsisType = ...) -> Scope : ...
105136
106- return tuple (state .active_scopes )
107137
138+ @overload
139+ def get_scope [T ](name : str , default : T ) -> Scope | T : ...
108140
109- def get_scope (name : str ) -> Scope :
110- state = __SCOPES .get (name )
111141
112- if state is None or (scope := state .get_scope ()) is None :
142+ def get_scope (name , default = ...): # type: ignore[no-untyped-def]
143+ for states in (__CONTEXTUAL_SCOPES , __SHARED_SCOPES ):
144+ state = states .get (name )
145+ if state and (scope := state .get_scope ()):
146+ return scope
147+
148+ if default is Ellipsis :
113149 raise ScopeUndefinedError (
114150 f"Scope `{ name } ` isn't defined in the current context."
115151 )
116152
117- return scope
153+ return default
118154
119155
120156@contextmanager
121157def _bind_scope (name : str , scope : Scope , shared : bool ) -> Iterator [None ]:
122- state = __SCOPES [name ]
158+ if shared :
159+ is_already_defined = bool (get_active_scopes (name ))
160+ state = __SHARED_SCOPES [name ]
161+
162+ else :
163+ is_already_defined = bool (get_scope (name , default = None ))
164+ state = __CONTEXTUAL_SCOPES [name ]
123165
124- if state . get_scope () :
166+ if is_already_defined :
125167 raise ScopeAlreadyDefinedError (
126168 f"Scope `{ name } ` is already defined in the current context."
127169 )
128170
129- strategy = state .bind_shared_scope if shared else state .bind_contextual_scope
130- with strategy (scope ):
171+ with state .bind (scope ):
131172 yield
132173
133174
0 commit comments