Skip to content

Commit e033786

Browse files
authored
refactoring: ♻️ scope.py
1 parent d88d2a3 commit e033786

1 file changed

Lines changed: 78 additions & 37 deletions

File tree

injection/_core/scope.py

Lines changed: 78 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

3+
import itertools
34
from abc import ABC, abstractmethod
45
from collections import defaultdict
5-
from collections.abc import AsyncIterator, Iterator, MutableMapping
6+
from collections.abc import AsyncIterator, Iterator, Mapping, MutableMapping
67
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
78
from contextvars import ContextVar
89
from dataclasses import dataclass, field
9-
from types import TracebackType
10+
from types import EllipsisType, TracebackType
1011
from typing import (
1112
Any,
1213
AsyncContextManager,
@@ -15,6 +16,7 @@
1516
NoReturn,
1617
Protocol,
1718
Self,
19+
overload,
1820
runtime_checkable,
1921
)
2022

@@ -26,32 +28,43 @@
2628
)
2729

2830

29-
@dataclass(repr=False, slots=True)
30-
class _ScopeState:
31-
# Shouldn't be instantiated outside `__SCOPES`.
31+
@runtime_checkable
32+
class ScopeState(Protocol):
33+
__slots__ = ()
34+
35+
@property
36+
@abstractmethod
37+
def active_scopes(self) -> Iterator[Scope]:
38+
raise NotImplementedError
39+
40+
@abstractmethod
41+
def bind(self, scope: Scope) -> ContextManager[None]:
42+
raise NotImplementedError
43+
44+
@abstractmethod
45+
def get_scope(self) -> Scope | None:
46+
raise NotImplementedError
47+
48+
49+
@dataclass(repr=False, frozen=True, slots=True)
50+
class _ContextualScopeState(ScopeState):
51+
# Shouldn't be instantiated outside `__CONTEXTUAL_SCOPES`.
3252

3353
__context_var: ContextVar[Scope] = field(
3454
default_factory=lambda: ContextVar(f"scope@{new_short_key()}"),
3555
init=False,
3656
)
37-
__default: Scope | None = field(
38-
default=None,
39-
init=False,
40-
)
4157
__references: set[Scope] = field(
4258
default_factory=set,
4359
init=False,
4460
)
4561

4662
@property
4763
def active_scopes(self) -> Iterator[Scope]:
48-
yield from self.__references
49-
50-
if default := self.__default:
51-
yield default
64+
return iter(self.__references)
5265

5366
@contextmanager
54-
def bind_contextual_scope(self, scope: Scope) -> Iterator[None]:
67+
def bind(self, scope: Scope) -> Iterator[None]:
5568
self.__references.add(scope)
5669
token = self.__context_var.set(scope)
5770

@@ -61,26 +74,38 @@ def bind_contextual_scope(self, scope: Scope) -> Iterator[None]:
6174
self.__context_var.reset(token)
6275
self.__references.remove(scope)
6376

64-
@contextmanager
65-
def bind_shared_scope(self, scope: Scope) -> Iterator[None]:
66-
if next(self.active_scopes, None):
67-
raise ScopeError(
68-
"A shared scope can't be defined when one or more contextual scopes "
69-
"are defined on the same name."
70-
)
77+
def get_scope(self) -> Scope | None:
78+
return self.__context_var.get(None)
79+
80+
81+
@dataclass(repr=False, slots=True)
82+
class _SharedScopeState(ScopeState):
83+
__scope: Scope | None = field(default=None)
84+
85+
@property
86+
def active_scopes(self) -> Iterator[Scope]:
87+
if scope := self.__scope:
88+
yield scope
7189

72-
self.__default = scope
90+
@contextmanager
91+
def bind(self, scope: Scope) -> Iterator[None]:
92+
self.__scope = scope
7393

7494
try:
7595
yield
7696
finally:
77-
self.__default = None
97+
self.__scope = None
7898

7999
def get_scope(self) -> Scope | None:
80-
return self.__context_var.get(self.__default)
100+
return self.__scope
81101

82102

83-
__SCOPES: Final[defaultdict[str, _ScopeState]] = defaultdict(_ScopeState)
103+
__CONTEXTUAL_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict(
104+
_ContextualScopeState,
105+
)
106+
__SHARED_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict(
107+
_SharedScopeState,
108+
)
84109

85110

86111
@asynccontextmanager
@@ -98,36 +123,52 @@ def define_scope(name: str, *, shared: bool = False) -> Iterator[None]:
98123

99124

100125
def get_active_scopes(name: str) -> tuple[Scope, ...]:
101-
state = __SCOPES.get(name)
126+
active_scopes = (
127+
state.active_scopes
128+
for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES)
129+
if (state := states.get(name))
130+
)
131+
return tuple(itertools.chain.from_iterable(active_scopes))
132+
102133

103-
if state is None:
104-
return ()
134+
@overload
135+
def get_scope(name: str, default: EllipsisType = ...) -> Scope: ...
105136

106-
return tuple(state.active_scopes)
107137

138+
@overload
139+
def get_scope[T](name: str, default: T) -> Scope | T: ...
108140

109-
def get_scope(name: str) -> Scope:
110-
state = __SCOPES.get(name)
111141

112-
if state is None or (scope := state.get_scope()) is None:
142+
def get_scope(name, default=...): # type: ignore[no-untyped-def]
143+
for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES):
144+
state = states.get(name)
145+
if state and (scope := state.get_scope()):
146+
return scope
147+
148+
if default is Ellipsis:
113149
raise ScopeUndefinedError(
114150
f"Scope `{name}` isn't defined in the current context."
115151
)
116152

117-
return scope
153+
return default
118154

119155

120156
@contextmanager
121157
def _bind_scope(name: str, scope: Scope, shared: bool) -> Iterator[None]:
122-
state = __SCOPES[name]
158+
if shared:
159+
is_already_defined = bool(get_active_scopes(name))
160+
state = __SHARED_SCOPES[name]
161+
162+
else:
163+
is_already_defined = bool(get_scope(name, default=None))
164+
state = __CONTEXTUAL_SCOPES[name]
123165

124-
if state.get_scope():
166+
if is_already_defined:
125167
raise ScopeAlreadyDefinedError(
126168
f"Scope `{name}` is already defined in the current context."
127169
)
128170

129-
strategy = state.bind_shared_scope if shared else state.bind_contextual_scope
130-
with strategy(scope):
171+
with state.bind(scope):
131172
yield
132173

133174

0 commit comments

Comments
 (0)