|
7 | 7 | from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager |
8 | 8 | from contextvars import ContextVar |
9 | 9 | from dataclasses import dataclass, field |
| 10 | +from enum import StrEnum |
10 | 11 | from types import EllipsisType, TracebackType |
11 | 12 | from typing import ( |
12 | 13 | Any, |
13 | 14 | AsyncContextManager, |
14 | 15 | ContextManager, |
15 | 16 | Final, |
| 17 | + Literal, |
16 | 18 | NoReturn, |
17 | 19 | Protocol, |
18 | 20 | Self, |
|
21 | 23 | ) |
22 | 24 |
|
23 | 25 | from injection._core.common.key import new_short_key |
| 26 | +from injection._core.slots import Slot |
24 | 27 | from injection.exceptions import ( |
| 28 | + InjectionError, |
25 | 29 | ScopeAlreadyDefinedError, |
26 | 30 | ScopeError, |
27 | 31 | ScopeUndefinedError, |
28 | 32 | ) |
29 | 33 |
|
30 | 34 |
|
| 35 | +class ScopeKind(StrEnum): |
| 36 | + CONTEXTUAL = "contextual" |
| 37 | + SHARED = "shared" |
| 38 | + |
| 39 | + @classmethod |
| 40 | + def get_default(cls) -> ScopeKind: |
| 41 | + return cls.CONTEXTUAL |
| 42 | + |
| 43 | + |
| 44 | +type ScopeKindStr = Literal["contextual", "shared"] |
| 45 | + |
| 46 | + |
31 | 47 | @runtime_checkable |
32 | 48 | class ScopeState(Protocol): |
33 | 49 | __slots__ = () |
@@ -109,17 +125,21 @@ def get_scope(self) -> Scope | None: |
109 | 125 |
|
110 | 126 |
|
111 | 127 | @asynccontextmanager |
112 | | -async def adefine_scope(name: str, *, shared: bool = False) -> AsyncIterator[None]: |
| 128 | +async def adefine_scope( |
| 129 | + name: str, |
| 130 | + kind: ScopeKind | ScopeKindStr = ScopeKind.get_default(), |
| 131 | +) -> AsyncIterator[ScopeFacade]: |
113 | 132 | async with AsyncScope() as scope: |
114 | | - scope.enter(_bind_scope(name, scope, shared)) |
115 | | - yield |
| 133 | + yield scope.enter(_bind_scope(name, scope, kind)) |
116 | 134 |
|
117 | 135 |
|
118 | 136 | @contextmanager |
119 | | -def define_scope(name: str, *, shared: bool = False) -> Iterator[None]: |
| 137 | +def define_scope( |
| 138 | + name: str, |
| 139 | + kind: ScopeKind | ScopeKindStr = ScopeKind.get_default(), |
| 140 | +) -> Iterator[ScopeFacade]: |
120 | 141 | with SyncScope() as scope: |
121 | | - scope.enter(_bind_scope(name, scope, shared)) |
122 | | - yield |
| 142 | + yield scope.enter(_bind_scope(name, scope, kind)) |
123 | 143 |
|
124 | 144 |
|
125 | 145 | def get_active_scopes(name: str) -> tuple[Scope, ...]: |
@@ -153,23 +173,40 @@ def get_scope(name, default=...): # type: ignore[no-untyped-def] |
153 | 173 | return default |
154 | 174 |
|
155 | 175 |
|
156 | | -@contextmanager |
157 | | -def _bind_scope(name: str, scope: Scope, shared: bool) -> Iterator[None]: |
158 | | - if shared: |
159 | | - is_already_defined = bool(get_active_scopes(name)) |
160 | | - states = __SHARED_SCOPES |
| 176 | +def in_scope_cache(key: Any, scope_name: str) -> bool: |
| 177 | + return any(key in scope.cache for scope in get_active_scopes(scope_name)) |
| 178 | + |
| 179 | + |
| 180 | +def remove_scoped_values(key: Any, scope_name: str) -> None: |
| 181 | + for scope in get_active_scopes(scope_name): |
| 182 | + scope.cache.pop(key, None) |
| 183 | + |
161 | 184 |
|
162 | | - else: |
163 | | - is_already_defined = bool(get_scope(name, default=None)) |
164 | | - states = __CONTEXTUAL_SCOPES |
| 185 | +@contextmanager |
| 186 | +def _bind_scope( |
| 187 | + name: str, |
| 188 | + scope: Scope, |
| 189 | + kind: ScopeKind | ScopeKindStr, |
| 190 | +) -> Iterator[ScopeFacade]: |
| 191 | + match ScopeKind(kind): |
| 192 | + case ScopeKind.CONTEXTUAL: |
| 193 | + is_already_defined = bool(get_scope(name, default=None)) |
| 194 | + states = __CONTEXTUAL_SCOPES |
| 195 | + |
| 196 | + case ScopeKind.SHARED: |
| 197 | + is_already_defined = bool(get_active_scopes(name)) |
| 198 | + states = __SHARED_SCOPES |
| 199 | + |
| 200 | + case _: |
| 201 | + raise NotImplementedError |
165 | 202 |
|
166 | 203 | if is_already_defined: |
167 | 204 | raise ScopeAlreadyDefinedError( |
168 | 205 | f"Scope `{name}` is already defined in the current context." |
169 | 206 | ) |
170 | 207 |
|
171 | 208 | with states[name].bind(scope): |
172 | | - yield |
| 209 | + yield ScopeFacade(scope) |
173 | 210 |
|
174 | 211 |
|
175 | 212 | @runtime_checkable |
@@ -245,3 +282,21 @@ async def aenter[T](self, context_manager: AsyncContextManager[T]) -> NoReturn: |
245 | 282 |
|
246 | 283 | def enter[T](self, context_manager: ContextManager[T]) -> T: |
247 | 284 | return self.delegate.enter_context(context_manager) |
| 285 | + |
| 286 | + |
| 287 | +@dataclass(repr=False, frozen=True, slots=True) |
| 288 | +class ScopeFacade: |
| 289 | + scope: Scope |
| 290 | + |
| 291 | + def set_slot[T](self, slot: Slot[T], value: T) -> Self: |
| 292 | + return self.slot_map({slot: value}) |
| 293 | + |
| 294 | + def slot_map(self, values: Mapping[Slot[Any], Any]) -> Self: |
| 295 | + cache = self.scope.cache |
| 296 | + |
| 297 | + for slot in values: |
| 298 | + if slot in cache: |
| 299 | + raise InjectionError("Slot already set.") |
| 300 | + |
| 301 | + cache.update(values) |
| 302 | + return self |
0 commit comments