Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 38 additions & 34 deletions injection/_core/injectables.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
runtime_checkable,
)

from injection._core.common.asynchronous import Caller, create_semaphore
from injection._core.common.asynchronous import Caller
from injection._core.common.asynchronous import (
create_semaphore as _create_async_semaphore,
)
from injection._core.scope import Scope, get_active_scopes, get_scope
from injection.exceptions import InjectionError

Expand All @@ -39,12 +42,8 @@ def get_instance(self) -> T:


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class BaseInjectable[R, T](Injectable[T], ABC):
factory: Caller[..., R]


class SimpleInjectable[T](BaseInjectable[T, T]):
__slots__ = ()
class SimpleInjectable[T](Injectable[T]):
factory: Caller[..., T]

async def aget_instance(self) -> T:
return await self.factory.acall()
Expand All @@ -53,13 +52,13 @@ def get_instance(self) -> T:
return self.factory.call()


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class CachedInjectable[R, T](BaseInjectable[R, T], ABC):
__semaphore: AsyncContextManager[Any] = field(
default_factory=partial(create_semaphore, 1),
init=False,
hash=False,
)
class CacheLogic[T]:
__slots__ = ("__semaphore",)

__semaphore: AsyncContextManager[Any]

def __init__(self) -> None:
self.__semaphore = _create_async_semaphore(1)

async def aget_or_create[K](
self,
Expand Down Expand Up @@ -90,32 +89,37 @@ def get_or_create[K](
return instance


class SingletonInjectable[T](CachedInjectable[T, T]):
__slots__ = ("__dict__",)
@dataclass(repr=False, eq=False, frozen=True, slots=True)
class SingletonInjectable[T](Injectable[T]):
factory: Caller[..., T]
cache: MutableMapping[str, T] = field(default_factory=dict)
logic: CacheLogic[T] = field(default_factory=CacheLogic)

__key: ClassVar[str] = "$instance"

@property
def is_locked(self) -> bool:
return self.__key in self.__cache

@property
def __cache(self) -> MutableMapping[str, Any]:
return self.__dict__
return self.__key in self.cache

async def aget_instance(self) -> T:
return await self.aget_or_create(self.__cache, self.__key, self.factory.acall)
return await self.logic.aget_or_create(
self.cache,
self.__key,
self.factory.acall,
)

def get_instance(self) -> T:
return self.get_or_create(self.__cache, self.__key, self.factory.call)
return self.logic.get_or_create(self.cache, self.__key, self.factory.call)

def unlock(self) -> None:
self.__cache.pop(self.__key, None)
self.cache.pop(self.__key, None)


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class ScopedInjectable[R, T](CachedInjectable[R, T], ABC):
class ScopedInjectable[R, T](Injectable[T], ABC):
factory: Caller[..., R]
scope_name: str
logic: CacheLogic[T] = field(default_factory=CacheLogic)

@property
def is_locked(self) -> bool:
Expand All @@ -130,26 +134,26 @@ def build(self, scope: Scope) -> T:
raise NotImplementedError

async def aget_instance(self) -> T:
scope = self.get_scope()
scope = self.__get_scope()
factory = partial(self.abuild, scope)
return await self.aget_or_create(scope.cache, self, factory)
return await self.logic.aget_or_create(scope.cache, self, factory)

def get_instance(self) -> T:
scope = self.get_scope()
scope = self.__get_scope()
factory = partial(self.build, scope)
return self.get_or_create(scope.cache, self, factory)

def get_scope(self) -> Scope:
return get_scope(self.scope_name)
return self.logic.get_or_create(scope.cache, self, factory)

def setdefault(self, instance: T) -> T:
scope = self.get_scope()
return self.get_or_create(scope.cache, self, lambda: instance)
scope = self.__get_scope()
return self.logic.get_or_create(scope.cache, self, lambda: instance)

def unlock(self) -> None:
if self.is_locked:
raise RuntimeError(f"To unlock, close the `{self.scope_name}` scope.")

def __get_scope(self) -> Scope:
return get_scope(self.scope_name)


class AsyncCMScopedInjectable[T](ScopedInjectable[AsyncContextManager[T], T]):
__slots__ = ()
Expand Down
Loading