11from abc import ABC , abstractmethod
2- from collections .abc import MutableMapping
2+ from collections .abc import Awaitable , Callable , MutableMapping
33from contextlib import suppress
4- from dataclasses import dataclass
4+ from dataclasses import dataclass , field
5+ from functools import partial
56from typing import (
67 Any ,
78 AsyncContextManager ,
1213 runtime_checkable ,
1314)
1415
15- from injection ._core .common .asynchronous import Caller
16+ from injection ._core .common .asynchronous import Caller , create_semaphore
1617from injection ._core .scope import Scope , get_active_scopes , get_scope
1718from injection .exceptions import InjectionError
1819
@@ -37,12 +38,12 @@ def get_instance(self) -> T:
3738 raise NotImplementedError
3839
3940
40- @dataclass (repr = False , frozen = True , slots = True )
41- class BaseInjectable [T ](Injectable [T ], ABC ):
42- factory : Caller [..., T ]
41+ @dataclass (repr = False , eq = False , frozen = True , slots = True )
42+ class BaseInjectable [R , T ](Injectable [T ], ABC ):
43+ factory : Caller [..., R ]
4344
4445
45- class SimpleInjectable [T ](BaseInjectable [T ]):
46+ class SimpleInjectable [T ](BaseInjectable [T , T ]):
4647 __slots__ = ()
4748
4849 async def aget_instance (self ) -> T :
@@ -52,7 +53,44 @@ def get_instance(self) -> T:
5253 return self .factory .call ()
5354
5455
55- class SingletonInjectable [T ](BaseInjectable [T ]):
56+ @dataclass (repr = False , eq = False , frozen = True , slots = True )
57+ class CachedInjectable [R , T ](BaseInjectable [R , T ], ABC ):
58+ __semaphore : AsyncContextManager [Any ] = field (
59+ default_factory = partial (create_semaphore , 1 ),
60+ init = False ,
61+ hash = False ,
62+ )
63+
64+ async def aget_or_create [K ](
65+ self ,
66+ cache : MutableMapping [K , T ],
67+ key : K ,
68+ factory : Callable [..., Awaitable [T ]],
69+ ) -> T :
70+ async with self .__semaphore :
71+ with suppress (KeyError ):
72+ return cache [key ]
73+
74+ instance = await factory ()
75+ cache [key ] = instance
76+
77+ return instance
78+
79+ def get_or_create [K ](
80+ self ,
81+ cache : MutableMapping [K , T ],
82+ key : K ,
83+ factory : Callable [..., T ],
84+ ) -> T :
85+ with suppress (KeyError ):
86+ return cache [key ]
87+
88+ instance = factory ()
89+ cache [key ] = instance
90+ return instance
91+
92+
93+ class SingletonInjectable [T ](CachedInjectable [T , T ]):
5694 __slots__ = ("__dict__" ,)
5795
5896 __key : ClassVar [str ] = "$instance"
@@ -66,32 +104,17 @@ def __cache(self) -> MutableMapping[str, Any]:
66104 return self .__dict__
67105
68106 async def aget_instance (self ) -> T :
69- cache = self .__cache
70-
71- with suppress (KeyError ):
72- return cache [self .__key ]
73-
74- instance = await self .factory .acall ()
75- cache [self .__key ] = instance
76- return instance
107+ return await self .aget_or_create (self .__cache , self .__key , self .factory .acall )
77108
78109 def get_instance (self ) -> T :
79- cache = self .__cache
80-
81- with suppress (KeyError ):
82- return cache [self .__key ]
83-
84- instance = self .factory .call ()
85- cache [self .__key ] = instance
86- return instance
110+ return self .get_or_create (self .__cache , self .__key , self .factory .call )
87111
88112 def unlock (self ) -> None :
89113 self .__cache .pop (self .__key , None )
90114
91115
92116@dataclass (repr = False , eq = False , frozen = True , slots = True )
93- class ScopedInjectable [R , T ](Injectable [T ], ABC ):
94- factory : Caller [..., R ]
117+ class ScopedInjectable [R , T ](CachedInjectable [R , T ], ABC ):
95118 scope_name : str
96119
97120 @property
@@ -108,29 +131,20 @@ def build(self, scope: Scope) -> T:
108131
109132 async def aget_instance (self ) -> T :
110133 scope = self .get_scope ()
111-
112- with suppress (KeyError ):
113- return scope .cache [self ]
114-
115- instance = await self .abuild (scope )
116- self .set_instance (instance , scope )
117- return instance
134+ factory = partial (self .abuild , scope )
135+ return await self .aget_or_create (scope .cache , self , factory )
118136
119137 def get_instance (self ) -> T :
120138 scope = self .get_scope ()
121-
122- with suppress (KeyError ):
123- return scope .cache [self ]
124-
125- instance = self .build (scope )
126- self .set_instance (instance , scope )
127- return instance
139+ factory = partial (self .build , scope )
140+ return self .get_or_create (scope .cache , self , factory )
128141
129142 def get_scope (self ) -> Scope :
130143 return get_scope (self .scope_name )
131144
132- def set_instance (self , instance : T , scope : Scope ) -> None :
133- scope .cache [self ] = instance
145+ def setdefault (self , instance : T ) -> T :
146+ scope = self .get_scope ()
147+ return self .get_or_create (scope .cache , self , lambda : instance )
134148
135149 def unlock (self ) -> None :
136150 if self .is_locked :
@@ -174,7 +188,7 @@ def unlock(self) -> None:
174188 scope .cache .pop (self , None )
175189
176190
177- @dataclass (repr = False , frozen = True , slots = True )
191+ @dataclass (repr = False , eq = False , frozen = True , slots = True )
178192class ShouldBeInjectable [T ](Injectable [T ]):
179193 cls : type [T ]
180194
0 commit comments