Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion injection/_core/common/asynchronous.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from collections.abc import Awaitable, Callable, Generator
from dataclasses import dataclass
from typing import Any, NoReturn, Protocol, runtime_checkable
from typing import Any, AsyncContextManager, NoReturn, Protocol, runtime_checkable


@dataclass(repr=False, eq=False, frozen=True, slots=True)
Expand Down Expand Up @@ -45,3 +45,15 @@ async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:

def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
return self.callable(*args, **kwargs)


try:
import anyio

def create_semaphore(value: int) -> AsyncContextManager[Any]:
return anyio.Semaphore(value)
except ImportError: # pragma: no cover
import asyncio

def create_semaphore(value: int) -> AsyncContextManager[Any]:
return asyncio.Semaphore(value)
100 changes: 57 additions & 43 deletions injection/_core/injectables.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from collections.abc import MutableMapping
from collections.abc import Awaitable, Callable, MutableMapping
from contextlib import suppress
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import partial
from typing import (
Any,
AsyncContextManager,
Expand All @@ -12,7 +13,7 @@
runtime_checkable,
)

from injection._core.common.asynchronous import Caller
from injection._core.common.asynchronous import Caller, create_semaphore
from injection._core.scope import Scope, get_active_scopes, get_scope
from injection.exceptions import InjectionError

Expand All @@ -37,12 +38,12 @@ def get_instance(self) -> T:
raise NotImplementedError


@dataclass(repr=False, frozen=True, slots=True)
class BaseInjectable[T](Injectable[T], ABC):
factory: Caller[..., T]
@dataclass(repr=False, eq=False, frozen=True, slots=True)
class BaseInjectable[R, T](Injectable[T], ABC):
factory: Caller[..., R]


class SimpleInjectable[T](BaseInjectable[T]):
class SimpleInjectable[T](BaseInjectable[T, T]):
__slots__ = ()

async def aget_instance(self) -> T:
Expand All @@ -52,7 +53,44 @@ def get_instance(self) -> T:
return self.factory.call()


class SingletonInjectable[T](BaseInjectable[T]):
@dataclass(repr=False, eq=False, frozen=True, slots=True)
class CachedInjectable[R, T](BaseInjectable[R, T], ABC):
__semaphore: AsyncContextManager[Any] = field(
default_factory=partial(create_semaphore, 1),
init=False,
hash=False,
)

async def aget_or_create[K](
self,
cache: MutableMapping[K, T],
key: K,
factory: Callable[..., Awaitable[T]],
) -> T:
async with self.__semaphore:
with suppress(KeyError):
return cache[key]

instance = await factory()
cache[key] = instance

return instance

def get_or_create[K](
self,
cache: MutableMapping[K, T],
key: K,
factory: Callable[..., T],
) -> T:
with suppress(KeyError):
return cache[key]

instance = factory()
cache[key] = instance
return instance


class SingletonInjectable[T](CachedInjectable[T, T]):
__slots__ = ("__dict__",)

__key: ClassVar[str] = "$instance"
Expand All @@ -66,32 +104,17 @@ def __cache(self) -> MutableMapping[str, Any]:
return self.__dict__

async def aget_instance(self) -> T:
cache = self.__cache

with suppress(KeyError):
return cache[self.__key]

instance = await self.factory.acall()
cache[self.__key] = instance
return instance
return await self.aget_or_create(self.__cache, self.__key, self.factory.acall)

def get_instance(self) -> T:
cache = self.__cache

with suppress(KeyError):
return cache[self.__key]

instance = self.factory.call()
cache[self.__key] = instance
return instance
return self.get_or_create(self.__cache, self.__key, self.factory.call)

def unlock(self) -> None:
self.__cache.pop(self.__key, None)


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class ScopedInjectable[R, T](Injectable[T], ABC):
factory: Caller[..., R]
class ScopedInjectable[R, T](CachedInjectable[R, T], ABC):
scope_name: str

@property
Expand All @@ -108,29 +131,20 @@ def build(self, scope: Scope) -> T:

async def aget_instance(self) -> T:
scope = self.get_scope()

with suppress(KeyError):
return scope.cache[self]

instance = await self.abuild(scope)
self.set_instance(instance, scope)
return instance
factory = partial(self.abuild, scope)
return await self.aget_or_create(scope.cache, self, factory)

def get_instance(self) -> T:
scope = self.get_scope()

with suppress(KeyError):
return scope.cache[self]

instance = self.build(scope)
self.set_instance(instance, scope)
return instance
factory = partial(self.build, scope)
return self.get_or_create(scope.cache, self, factory)

def get_scope(self) -> Scope:
return get_scope(self.scope_name)

def set_instance(self, instance: T, scope: Scope) -> None:
scope.cache[self] = instance
def setdefault(self, instance: T) -> T:
scope = self.get_scope()
return self.get_or_create(scope.cache, self, lambda: instance)

def unlock(self) -> None:
if self.is_locked:
Expand Down Expand Up @@ -174,7 +188,7 @@ def unlock(self) -> None:
scope.cache.pop(self, None)


@dataclass(repr=False, frozen=True, slots=True)
@dataclass(repr=False, eq=False, frozen=True, slots=True)
class ShouldBeInjectable[T](Injectable[T]):
cls: type[T]

Expand Down
3 changes: 2 additions & 1 deletion injection/_core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ async def all_ready(self) -> None:
if injectable.is_locked:
continue

await injectable.aget_instance()
with suppress(SkipInjectable):
await injectable.aget_instance()

def add_listener(self, listener: EventListener) -> Self:
self.__channel.add_listener(listener)
Expand Down
11 changes: 8 additions & 3 deletions injection/_core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,18 @@ def define_scope(name: str, *, shared: bool = False) -> Iterator[None]:


def get_active_scopes(name: str) -> tuple[Scope, ...]:
return tuple(__SCOPES[name].active_scopes)
state = __SCOPES.get(name)

if state is None:
return ()

return tuple(state.active_scopes)


def get_scope(name: str) -> Scope:
scope = __SCOPES[name].get_scope()
state = __SCOPES.get(name)

if scope is None:
if state is None or (scope := state.get_scope()) is None:
raise ScopeUndefinedError(
f"Scope `{name}` isn't defined in the current context."
)
Expand Down
6 changes: 3 additions & 3 deletions injection/_core/slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Protocol, runtime_checkable

from injection._core.injectables import ScopedInjectable
from injection.exceptions import InjectionError


@runtime_checkable
Expand All @@ -19,6 +20,5 @@ class ScopedSlot[T](Slot[T]):
injectable: ScopedInjectable[Any, T]

def set(self, instance: T, /) -> None:
injectable = self.injectable
scope = injectable.get_scope()
injectable.set_instance(instance, scope)
if self.injectable.setdefault(instance) is not instance:
raise InjectionError("Slot already set.")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ bench = [
"types-tabulate",
]
dev = [
"anyio",
"hatch",
"mypy",
"ruff",
Expand Down
17 changes: 17 additions & 0 deletions tests/core/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from injection import Module, define_scope
from injection.exceptions import (
EmptySlotError,
InjectionError,
ModuleError,
ModuleLockError,
ModuleNotUsedError,
Expand Down Expand Up @@ -260,6 +261,22 @@ def test_reserve_scoped_slot_with_empty_raise_empty_slot_error(self, module):
with pytest.raises(EmptySlotError):
module.find_instance(SomeClass)

def test_reserve_scoped_slot_with_several_definitions_raise_injection_error(
self,
module,
):
scope_name = "test"
slot = module.reserve_scoped_slot(SomeClass, scope_name)

with define_scope(scope_name):
instance1 = SomeClass()
slot.set(instance1)

instance2 = SomeClass()

with pytest.raises(InjectionError):
slot.set(instance2)

"""
init_modules
"""
Expand Down
Loading