Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 78 additions & 37 deletions injection/_core/scope.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import itertools
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import AsyncIterator, Iterator, MutableMapping
from collections.abc import AsyncIterator, Iterator, Mapping, MutableMapping
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from types import TracebackType
from types import EllipsisType, TracebackType
from typing import (
Any,
AsyncContextManager,
Expand All @@ -15,6 +16,7 @@
NoReturn,
Protocol,
Self,
overload,
runtime_checkable,
)

Expand All @@ -26,32 +28,43 @@
)


@dataclass(repr=False, slots=True)
class _ScopeState:
# Shouldn't be instantiated outside `__SCOPES`.
@runtime_checkable
class ScopeState(Protocol):
__slots__ = ()

@property
@abstractmethod
def active_scopes(self) -> Iterator[Scope]:
raise NotImplementedError

@abstractmethod
def bind(self, scope: Scope) -> ContextManager[None]:
raise NotImplementedError

@abstractmethod
def get_scope(self) -> Scope | None:
raise NotImplementedError


@dataclass(repr=False, frozen=True, slots=True)
class _ContextualScopeState(ScopeState):
# Shouldn't be instantiated outside `__CONTEXTUAL_SCOPES`.

__context_var: ContextVar[Scope] = field(
default_factory=lambda: ContextVar(f"scope@{new_short_key()}"),
init=False,
)
__default: Scope | None = field(
default=None,
init=False,
)
__references: set[Scope] = field(
default_factory=set,
init=False,
)

@property
def active_scopes(self) -> Iterator[Scope]:
yield from self.__references

if default := self.__default:
yield default
return iter(self.__references)

@contextmanager
def bind_contextual_scope(self, scope: Scope) -> Iterator[None]:
def bind(self, scope: Scope) -> Iterator[None]:
self.__references.add(scope)
token = self.__context_var.set(scope)

Expand All @@ -61,26 +74,38 @@ def bind_contextual_scope(self, scope: Scope) -> Iterator[None]:
self.__context_var.reset(token)
self.__references.remove(scope)

@contextmanager
def bind_shared_scope(self, scope: Scope) -> Iterator[None]:
if next(self.active_scopes, None):
raise ScopeError(
"A shared scope can't be defined when one or more contextual scopes "
"are defined on the same name."
)
def get_scope(self) -> Scope | None:
return self.__context_var.get(None)


@dataclass(repr=False, slots=True)
class _SharedScopeState(ScopeState):
__scope: Scope | None = field(default=None)

@property
def active_scopes(self) -> Iterator[Scope]:
if scope := self.__scope:
yield scope

self.__default = scope
@contextmanager
def bind(self, scope: Scope) -> Iterator[None]:
self.__scope = scope

try:
yield
finally:
self.__default = None
self.__scope = None

def get_scope(self) -> Scope | None:
return self.__context_var.get(self.__default)
return self.__scope


__SCOPES: Final[defaultdict[str, _ScopeState]] = defaultdict(_ScopeState)
__CONTEXTUAL_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict(
_ContextualScopeState,
)
__SHARED_SCOPES: Final[Mapping[str, ScopeState]] = defaultdict(
_SharedScopeState,
)


@asynccontextmanager
Expand All @@ -98,36 +123,52 @@ def define_scope(name: str, *, shared: bool = False) -> Iterator[None]:


def get_active_scopes(name: str) -> tuple[Scope, ...]:
state = __SCOPES.get(name)
active_scopes = (
state.active_scopes
for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES)
if (state := states.get(name))
)
return tuple(itertools.chain.from_iterable(active_scopes))


if state is None:
return ()
@overload
def get_scope(name: str, default: EllipsisType = ...) -> Scope: ...

return tuple(state.active_scopes)

@overload
def get_scope[T](name: str, default: T) -> Scope | T: ...

def get_scope(name: str) -> Scope:
state = __SCOPES.get(name)

if state is None or (scope := state.get_scope()) is None:
def get_scope(name, default=...): # type: ignore[no-untyped-def]
for states in (__CONTEXTUAL_SCOPES, __SHARED_SCOPES):
state = states.get(name)
if state and (scope := state.get_scope()):
return scope

if default is Ellipsis:
raise ScopeUndefinedError(
f"Scope `{name}` isn't defined in the current context."
)

return scope
return default


@contextmanager
def _bind_scope(name: str, scope: Scope, shared: bool) -> Iterator[None]:
state = __SCOPES[name]
if shared:
is_already_defined = bool(get_active_scopes(name))
state = __SHARED_SCOPES[name]

else:
is_already_defined = bool(get_scope(name, default=None))
state = __CONTEXTUAL_SCOPES[name]

if state.get_scope():
if is_already_defined:
raise ScopeAlreadyDefinedError(
f"Scope `{name}` is already defined in the current context."
)

strategy = state.bind_shared_scope if shared else state.bind_contextual_scope
with strategy(scope):
with state.bind(scope):
yield


Expand Down