From b7de4b2dfbd01b6f391a3d61e984456893bbddf2 Mon Sep 17 00:00:00 2001 From: remimd Date: Mon, 31 Mar 2025 11:07:56 +0200 Subject: [PATCH] =?UTF-8?q?refactoring:=20=E2=99=BB=EF=B8=8F=20`scope.py`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- injection/_core/scope.py | 115 ++++++++++++++++++++++++++------------- 1 file changed, 78 insertions(+), 37 deletions(-) diff --git a/injection/_core/scope.py b/injection/_core/scope.py index 910fd7a..9add310 100644 --- a/injection/_core/scope.py +++ b/injection/_core/scope.py @@ -1,12 +1,13 @@ from __future__ import annotations +import itertools from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import AsyncIterator, Iterator, MutableMapping +from collections.abc import AsyncIterator, Iterator, Mapping, MutableMapping from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar from dataclasses import dataclass, field -from types import TracebackType +from types import EllipsisType, TracebackType from typing import ( Any, AsyncContextManager, @@ -15,6 +16,7 @@ NoReturn, Protocol, Self, + overload, runtime_checkable, ) @@ -26,18 +28,32 @@ ) -@dataclass(repr=False, slots=True) -class _ScopeState: - # Shouldn't be instantiated outside `__SCOPES`. +@runtime_checkable +class ScopeState(Protocol): + __slots__ = () + + @property + @abstractmethod + def active_scopes(self) -> Iterator[Scope]: + raise NotImplementedError + + @abstractmethod + def bind(self, scope: Scope) -> ContextManager[None]: + raise NotImplementedError + + @abstractmethod + def get_scope(self) -> Scope | None: + raise NotImplementedError + + +@dataclass(repr=False, frozen=True, slots=True) +class _ContextualScopeState(ScopeState): + # Shouldn't be instantiated outside `__CONTEXTUAL_SCOPES`. __context_var: ContextVar[Scope] = field( default_factory=lambda: ContextVar(f"scope@{new_short_key()}"), init=False, ) - __default: Scope | None = field( - default=None, - init=False, - ) __references: set[Scope] = field( default_factory=set, init=False, @@ -45,13 +61,10 @@ class _ScopeState: @property def active_scopes(self) -> Iterator[Scope]: - yield from self.__references - - if default := self.__default: - yield default + return iter(self.__references) @contextmanager - def bind_contextual_scope(self, scope: Scope) -> Iterator[None]: + def bind(self, scope: Scope) -> Iterator[None]: self.__references.add(scope) token = self.__context_var.set(scope) @@ -61,26 +74,38 @@ def bind_contextual_scope(self, scope: Scope) -> Iterator[None]: self.__context_var.reset(token) self.__references.remove(scope) - @contextmanager - def bind_shared_scope(self, scope: Scope) -> Iterator[None]: - if next(self.active_scopes, None): - raise ScopeError( - "A shared scope can't be defined when one or more contextual scopes " - "are defined on the same name." - ) + def get_scope(self) -> Scope | None: + return self.__context_var.get(None) + + +@dataclass(repr=False, slots=True) +class _SharedScopeState(ScopeState): + __scope: Scope | None = field(default=None) + + @property + def active_scopes(self) -> Iterator[Scope]: + if scope := self.__scope: + yield scope - self.__default = scope + @contextmanager + def bind(self, scope: Scope) -> Iterator[None]: + self.__scope = scope try: yield finally: - self.__default = None + self.__scope = None def get_scope(self) -> Scope | None: - return self.__context_var.get(self.__default) + return self.__scope -__SCOPES: Final[defaultdict[str, _ScopeState]] = defaultdict(_ScopeState) +__CONTEXTUAL_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict( + _ContextualScopeState, +) +__SHARED_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict( + _SharedScopeState, +) @asynccontextmanager @@ -98,36 +123,52 @@ def define_scope(name: str, *, shared: bool = False) -> Iterator[None]: def get_active_scopes(name: str) -> tuple[Scope, ...]: - state = __SCOPES.get(name) + active_scopes = ( + state.active_scopes + for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES) + if (state := states.get(name)) + ) + return tuple(itertools.chain.from_iterable(active_scopes)) + - if state is None: - return () +@overload +def get_scope(name: str, default: EllipsisType = ...) -> Scope: ... - return tuple(state.active_scopes) +@overload +def get_scope[T](name: str, default: T) -> Scope | T: ... -def get_scope(name: str) -> Scope: - state = __SCOPES.get(name) - if state is None or (scope := state.get_scope()) is None: +def get_scope(name, default=...): # type: ignore[no-untyped-def] + for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES): + state = states.get(name) + if state and (scope := state.get_scope()): + return scope + + if default is Ellipsis: raise ScopeUndefinedError( f"Scope `{name}` isn't defined in the current context." ) - return scope + return default @contextmanager def _bind_scope(name: str, scope: Scope, shared: bool) -> Iterator[None]: - state = __SCOPES[name] + if shared: + is_already_defined = bool(get_active_scopes(name)) + state = __SHARED_SCOPES[name] + + else: + is_already_defined = bool(get_scope(name, default=None)) + state = __CONTEXTUAL_SCOPES[name] - if state.get_scope(): + if is_already_defined: raise ScopeAlreadyDefinedError( f"Scope `{name}` is already defined in the current context." ) - strategy = state.bind_shared_scope if shared else state.bind_contextual_scope - with strategy(scope): + with state.bind(scope): yield