Skip to content

Commit 3709268

Browse files
authored
fix: 🐛 Protecting concurrent cache access with async recipes
1 parent 26f5991 commit 3709268

8 files changed

Lines changed: 254 additions & 186 deletions

File tree

injection/_core/common/asynchronous.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import abstractmethod
22
from collections.abc import Awaitable, Callable, Generator
33
from dataclasses import dataclass
4-
from typing import Any, NoReturn, Protocol, runtime_checkable
4+
from typing import Any, AsyncContextManager, NoReturn, Protocol, runtime_checkable
55

66

77
@dataclass(repr=False, eq=False, frozen=True, slots=True)
@@ -45,3 +45,15 @@ async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
4545

4646
def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
4747
return self.callable(*args, **kwargs)
48+
49+
50+
try:
51+
import anyio
52+
53+
def create_semaphore(value: int) -> AsyncContextManager[Any]:
54+
return anyio.Semaphore(value)
55+
except ImportError: # pragma: no cover
56+
import asyncio
57+
58+
def create_semaphore(value: int) -> AsyncContextManager[Any]:
59+
return asyncio.Semaphore(value)

injection/_core/injectables.py

Lines changed: 57 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from abc import ABC, abstractmethod
2-
from collections.abc import MutableMapping
2+
from collections.abc import Awaitable, Callable, MutableMapping
33
from contextlib import suppress
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
5+
from functools import partial
56
from typing import (
67
Any,
78
AsyncContextManager,
@@ -12,7 +13,7 @@
1213
runtime_checkable,
1314
)
1415

15-
from injection._core.common.asynchronous import Caller
16+
from injection._core.common.asynchronous import Caller, create_semaphore
1617
from injection._core.scope import Scope, get_active_scopes, get_scope
1718
from injection.exceptions import InjectionError
1819

@@ -37,12 +38,12 @@ def get_instance(self) -> T:
3738
raise NotImplementedError
3839

3940

40-
@dataclass(repr=False, frozen=True, slots=True)
41-
class BaseInjectable[T](Injectable[T], ABC):
42-
factory: Caller[..., T]
41+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
42+
class BaseInjectable[R, T](Injectable[T], ABC):
43+
factory: Caller[..., R]
4344

4445

45-
class SimpleInjectable[T](BaseInjectable[T]):
46+
class SimpleInjectable[T](BaseInjectable[T, T]):
4647
__slots__ = ()
4748

4849
async def aget_instance(self) -> T:
@@ -52,7 +53,44 @@ def get_instance(self) -> T:
5253
return self.factory.call()
5354

5455

55-
class SingletonInjectable[T](BaseInjectable[T]):
56+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
57+
class CachedInjectable[R, T](BaseInjectable[R, T], ABC):
58+
__semaphore: AsyncContextManager[Any] = field(
59+
default_factory=partial(create_semaphore, 1),
60+
init=False,
61+
hash=False,
62+
)
63+
64+
async def aget_or_create[K](
65+
self,
66+
cache: MutableMapping[K, T],
67+
key: K,
68+
factory: Callable[..., Awaitable[T]],
69+
) -> T:
70+
async with self.__semaphore:
71+
with suppress(KeyError):
72+
return cache[key]
73+
74+
instance = await factory()
75+
cache[key] = instance
76+
77+
return instance
78+
79+
def get_or_create[K](
80+
self,
81+
cache: MutableMapping[K, T],
82+
key: K,
83+
factory: Callable[..., T],
84+
) -> T:
85+
with suppress(KeyError):
86+
return cache[key]
87+
88+
instance = factory()
89+
cache[key] = instance
90+
return instance
91+
92+
93+
class SingletonInjectable[T](CachedInjectable[T, T]):
5694
__slots__ = ("__dict__",)
5795

5896
__key: ClassVar[str] = "$instance"
@@ -66,32 +104,17 @@ def __cache(self) -> MutableMapping[str, Any]:
66104
return self.__dict__
67105

68106
async def aget_instance(self) -> T:
69-
cache = self.__cache
70-
71-
with suppress(KeyError):
72-
return cache[self.__key]
73-
74-
instance = await self.factory.acall()
75-
cache[self.__key] = instance
76-
return instance
107+
return await self.aget_or_create(self.__cache, self.__key, self.factory.acall)
77108

78109
def get_instance(self) -> T:
79-
cache = self.__cache
80-
81-
with suppress(KeyError):
82-
return cache[self.__key]
83-
84-
instance = self.factory.call()
85-
cache[self.__key] = instance
86-
return instance
110+
return self.get_or_create(self.__cache, self.__key, self.factory.call)
87111

88112
def unlock(self) -> None:
89113
self.__cache.pop(self.__key, None)
90114

91115

92116
@dataclass(repr=False, eq=False, frozen=True, slots=True)
93-
class ScopedInjectable[R, T](Injectable[T], ABC):
94-
factory: Caller[..., R]
117+
class ScopedInjectable[R, T](CachedInjectable[R, T], ABC):
95118
scope_name: str
96119

97120
@property
@@ -108,29 +131,20 @@ def build(self, scope: Scope) -> T:
108131

109132
async def aget_instance(self) -> T:
110133
scope = self.get_scope()
111-
112-
with suppress(KeyError):
113-
return scope.cache[self]
114-
115-
instance = await self.abuild(scope)
116-
self.set_instance(instance, scope)
117-
return instance
134+
factory = partial(self.abuild, scope)
135+
return await self.aget_or_create(scope.cache, self, factory)
118136

119137
def get_instance(self) -> T:
120138
scope = self.get_scope()
121-
122-
with suppress(KeyError):
123-
return scope.cache[self]
124-
125-
instance = self.build(scope)
126-
self.set_instance(instance, scope)
127-
return instance
139+
factory = partial(self.build, scope)
140+
return self.get_or_create(scope.cache, self, factory)
128141

129142
def get_scope(self) -> Scope:
130143
return get_scope(self.scope_name)
131144

132-
def set_instance(self, instance: T, scope: Scope) -> None:
133-
scope.cache[self] = instance
145+
def setdefault(self, instance: T) -> T:
146+
scope = self.get_scope()
147+
return self.get_or_create(scope.cache, self, lambda: instance)
134148

135149
def unlock(self) -> None:
136150
if self.is_locked:
@@ -174,7 +188,7 @@ def unlock(self) -> None:
174188
scope.cache.pop(self, None)
175189

176190

177-
@dataclass(repr=False, frozen=True, slots=True)
191+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
178192
class ShouldBeInjectable[T](Injectable[T]):
179193
cls: type[T]
180194

injection/_core/module.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,8 @@ async def all_ready(self) -> None:
300300
if injectable.is_locked:
301301
continue
302302

303-
await injectable.aget_instance()
303+
with suppress(SkipInjectable):
304+
await injectable.aget_instance()
304305

305306
def add_listener(self, listener: EventListener) -> Self:
306307
self.__channel.add_listener(listener)

injection/_core/scope.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,18 @@ def define_scope(name: str, *, shared: bool = False) -> Iterator[None]:
9898

9999

100100
def get_active_scopes(name: str) -> tuple[Scope, ...]:
101-
return tuple(__SCOPES[name].active_scopes)
101+
state = __SCOPES.get(name)
102+
103+
if state is None:
104+
return ()
105+
106+
return tuple(state.active_scopes)
102107

103108

104109
def get_scope(name: str) -> Scope:
105-
scope = __SCOPES[name].get_scope()
110+
state = __SCOPES.get(name)
106111

107-
if scope is None:
112+
if state is None or (scope := state.get_scope()) is None:
108113
raise ScopeUndefinedError(
109114
f"Scope `{name}` isn't defined in the current context."
110115
)

injection/_core/slots.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Protocol, runtime_checkable
44

55
from injection._core.injectables import ScopedInjectable
6+
from injection.exceptions import InjectionError
67

78

89
@runtime_checkable
@@ -19,6 +20,5 @@ class ScopedSlot[T](Slot[T]):
1920
injectable: ScopedInjectable[Any, T]
2021

2122
def set(self, instance: T, /) -> None:
22-
injectable = self.injectable
23-
scope = injectable.get_scope()
24-
injectable.set_instance(instance, scope)
23+
if self.injectable.setdefault(instance) is not instance:
24+
raise InjectionError("Slot already set.")

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ bench = [
99
"types-tabulate",
1010
]
1111
dev = [
12+
"anyio",
1213
"hatch",
1314
"mypy",
1415
"ruff",

tests/core/test_module.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from injection import Module, define_scope
88
from injection.exceptions import (
99
EmptySlotError,
10+
InjectionError,
1011
ModuleError,
1112
ModuleLockError,
1213
ModuleNotUsedError,
@@ -260,6 +261,22 @@ def test_reserve_scoped_slot_with_empty_raise_empty_slot_error(self, module):
260261
with pytest.raises(EmptySlotError):
261262
module.find_instance(SomeClass)
262263

264+
def test_reserve_scoped_slot_with_several_definitions_raise_injection_error(
265+
self,
266+
module,
267+
):
268+
scope_name = "test"
269+
slot = module.reserve_scoped_slot(SomeClass, scope_name)
270+
271+
with define_scope(scope_name):
272+
instance1 = SomeClass()
273+
slot.set(instance1)
274+
275+
instance2 = SomeClass()
276+
277+
with pytest.raises(InjectionError):
278+
slot.set(instance2)
279+
263280
"""
264281
init_modules
265282
"""

0 commit comments

Comments
 (0)