Skip to content

Commit 684424d

Browse files
authored
feat: ✨ Add threadsafe parameter to instance getters
1 parent 05c204d commit 684424d

12 files changed

Lines changed: 163 additions & 148 deletions

File tree

injection/__init__.pyi

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,13 @@ class Module:
217217
/,
218218
threadsafe: bool = ...,
219219
) -> Callable[..., Awaitable[T]]: ...
220-
async def afind_instance[T](self, cls: _InputType[T]) -> T: ...
221-
def find_instance[T](self, cls: _InputType[T]) -> T:
220+
async def afind_instance[T](
221+
self,
222+
cls: _InputType[T],
223+
*,
224+
threadsafe: bool = ...,
225+
) -> T: ...
226+
def find_instance[T](self, cls: _InputType[T], *, threadsafe: bool = ...) -> T:
222227
"""
223228
Function used to retrieve an instance associated with the type passed in
224229
parameter or an exception will be raised.
@@ -229,59 +234,66 @@ class Module:
229234
self,
230235
cls: _InputType[T],
231236
default: Default,
237+
*,
238+
threadsafe: bool = ...,
232239
) -> T | Default: ...
233240
@overload
234241
async def aget_instance[T](
235242
self,
236243
cls: _InputType[T],
237-
default: None = ...,
238-
) -> T | None: ...
244+
default: T = ...,
245+
*,
246+
threadsafe: bool = ...,
247+
) -> T: ...
239248
@overload
240249
def get_instance[T, Default](
241250
self,
242251
cls: _InputType[T],
243252
default: Default,
253+
*,
254+
threadsafe: bool = ...,
244255
) -> T | Default:
245256
"""
246257
Function used to retrieve an instance associated with the type passed in
247-
parameter or return `None`.
258+
parameter or return `NotImplemented`.
248259
"""
249260

250261
@overload
251262
def get_instance[T](
252263
self,
253264
cls: _InputType[T],
254-
default: None = ...,
255-
) -> T | None: ...
265+
default: T = ...,
266+
*,
267+
threadsafe: bool = ...,
268+
) -> T: ...
256269
@overload
257270
def aget_lazy_instance[T, Default](
258271
self,
259272
cls: _InputType[T],
260273
default: Default,
261274
*,
262-
cache: bool = ...,
275+
threadsafe: bool = ...,
263276
) -> Awaitable[T | Default]: ...
264277
@overload
265278
def aget_lazy_instance[T](
266279
self,
267280
cls: _InputType[T],
268-
default: None = ...,
281+
default: T = ...,
269282
*,
270-
cache: bool = ...,
271-
) -> Awaitable[T | None]: ...
283+
threadsafe: bool = ...,
284+
) -> Awaitable[T]: ...
272285
@overload
273286
def get_lazy_instance[T, Default](
274287
self,
275288
cls: _InputType[T],
276289
default: Default,
277290
*,
278-
cache: bool = ...,
291+
threadsafe: bool = ...,
279292
) -> _Invertible[T | Default]:
280293
"""
281294
Function used to retrieve an instance associated with the type passed in
282-
parameter or `None`. Return a `Invertible` object. To access the instance
295+
parameter or `NotImplemented`. Return a `Invertible` object. To access the instance
283296
contained in an invertible object, simply use a wavy line (~).
284-
With `cache=True`, the instance retrieved will always be the same.
285297
286298
Example: instance = ~lazy_instance
287299
"""
@@ -290,10 +302,10 @@ class Module:
290302
def get_lazy_instance[T](
291303
self,
292304
cls: _InputType[T],
293-
default: None = ...,
305+
default: T = ...,
294306
*,
295-
cache: bool = ...,
296-
) -> _Invertible[T | None]: ...
307+
threadsafe: bool = ...,
308+
) -> _Invertible[T]: ...
297309
def init_modules(self, *modules: Module) -> Self:
298310
"""
299311
Function to clean modules in use and to use those passed as parameters.

injection/_core/common/asynchronous.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
5252

5353
def create_semaphore(value: int) -> AsyncContextManager[Any]:
5454
return anyio.Semaphore(value)
55+
5556
except ImportError: # pragma: no cover
5657
import asyncio
5758

injection/_core/common/lazy.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator
1+
from collections.abc import Callable, Iterator
22
from functools import partial
33

4-
from injection._core.common.asynchronous import SimpleAwaitable
54
from injection._core.common.invertible import Invertible, SimpleInvertible
65

76

@@ -18,19 +17,6 @@ def cache() -> Iterator[T]:
1817
return SimpleInvertible(getter)
1918

2019

21-
def alazy[T](factory: Callable[..., Awaitable[T]]) -> Awaitable[T]:
22-
async def cache() -> AsyncIterator[T]:
23-
nonlocal factory
24-
value = await factory()
25-
del factory
26-
27-
while True:
28-
yield value
29-
30-
getter = partial(anext, cache())
31-
return SimpleAwaitable(getter)
32-
33-
3420
class Lazy[T](Invertible[T]):
3521
__slots__ = ("__invertible", "__is_set")
3622

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from contextlib import nullcontext
2+
from threading import RLock
3+
from typing import Any, ContextManager
4+
5+
6+
def get_lock(threadsafe: bool) -> ContextManager[Any]:
7+
return RLock() if threadsafe else nullcontext()

injection/_core/module.py

Lines changed: 61 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import threading
43
from abc import ABC, abstractmethod
54
from collections import OrderedDict, deque
65
from collections.abc import (
@@ -14,7 +13,7 @@
1413
Iterator,
1514
Mapping,
1615
)
17-
from contextlib import asynccontextmanager, contextmanager, nullcontext, suppress
16+
from contextlib import asynccontextmanager, contextmanager, suppress
1817
from dataclasses import dataclass, field
1918
from enum import StrEnum
2019
from functools import partial, partialmethod, singledispatchmethod, update_wrapper
@@ -51,7 +50,8 @@
5150
from injection._core.common.event import Event, EventChannel, EventListener
5251
from injection._core.common.invertible import Invertible, SimpleInvertible
5352
from injection._core.common.key import new_short_key
54-
from injection._core.common.lazy import Lazy, alazy, lazy
53+
from injection._core.common.lazy import Lazy, lazy
54+
from injection._core.common.threading import get_lock
5555
from injection._core.common.type import (
5656
InputType,
5757
TypeInfo,
@@ -617,35 +617,48 @@ def make_async_factory[T](
617617
)
618618
return factory.__inject_metadata__.acall
619619

620-
async def afind_instance[T](self, cls: InputType[T]) -> T:
621-
injectable = self[cls]
622-
return await injectable.aget_instance()
620+
async def afind_instance[T](
621+
self,
622+
cls: InputType[T],
623+
*,
624+
threadsafe: bool = False,
625+
) -> T:
626+
with get_lock(threadsafe):
627+
injectable = self[cls]
628+
return await injectable.aget_instance()
623629

624-
def find_instance[T](self, cls: InputType[T]) -> T:
625-
injectable = self[cls]
626-
return injectable.get_instance()
630+
def find_instance[T](self, cls: InputType[T], *, threadsafe: bool = False) -> T:
631+
with get_lock(threadsafe):
632+
injectable = self[cls]
633+
return injectable.get_instance()
627634

628635
@overload
629636
async def aget_instance[T, Default](
630637
self,
631638
cls: InputType[T],
632639
default: Default,
640+
*,
641+
threadsafe: bool = ...,
633642
) -> T | Default: ...
634643

635644
@overload
636645
async def aget_instance[T](
637646
self,
638647
cls: InputType[T],
639-
default: None = ...,
640-
) -> T | None: ...
648+
default: T = ...,
649+
*,
650+
threadsafe: bool = ...,
651+
) -> T: ...
641652

642653
async def aget_instance[T, Default](
643654
self,
644655
cls: InputType[T],
645-
default: Default | None = None,
646-
) -> T | Default | None:
656+
default: Default = NotImplemented,
657+
*,
658+
threadsafe: bool = False,
659+
) -> T | Default:
647660
try:
648-
return await self.afind_instance(cls)
661+
return await self.afind_instance(cls, threadsafe=threadsafe)
649662
except (KeyError, SkipInjectable):
650663
return default
651664

@@ -654,22 +667,28 @@ def get_instance[T, Default](
654667
self,
655668
cls: InputType[T],
656669
default: Default,
670+
*,
671+
threadsafe: bool = ...,
657672
) -> T | Default: ...
658673

659674
@overload
660675
def get_instance[T](
661676
self,
662677
cls: InputType[T],
663-
default: None = ...,
664-
) -> T | None: ...
678+
default: T = ...,
679+
*,
680+
threadsafe: bool = ...,
681+
) -> T: ...
665682

666683
def get_instance[T, Default](
667684
self,
668685
cls: InputType[T],
669-
default: Default | None = None,
670-
) -> T | Default | None:
686+
default: Default = NotImplemented,
687+
*,
688+
threadsafe: bool = False,
689+
) -> T | Default:
671690
try:
672-
return self.find_instance(cls)
691+
return self.find_instance(cls, threadsafe=threadsafe)
673692
except (KeyError, SkipInjectable):
674693
return default
675694

@@ -679,29 +698,29 @@ def aget_lazy_instance[T, Default](
679698
cls: InputType[T],
680699
default: Default,
681700
*,
682-
cache: bool = ...,
701+
threadsafe: bool = ...,
683702
) -> Awaitable[T | Default]: ...
684703

685704
@overload
686705
def aget_lazy_instance[T](
687706
self,
688707
cls: InputType[T],
689-
default: None = ...,
708+
default: T = ...,
690709
*,
691-
cache: bool = ...,
692-
) -> Awaitable[T | None]: ...
710+
threadsafe: bool = ...,
711+
) -> Awaitable[T]: ...
693712

694713
def aget_lazy_instance[T, Default](
695714
self,
696715
cls: InputType[T],
697-
default: Default | None = None,
716+
default: Default = NotImplemented,
698717
*,
699-
cache: bool = False,
700-
) -> Awaitable[T | Default | None]:
701-
if cache:
702-
return alazy(lambda: self.aget_instance(cls, default))
703-
704-
function = self.make_injected_function(lambda instance=default: instance)
718+
threadsafe: bool = False,
719+
) -> Awaitable[T | Default]:
720+
function = self.make_injected_function(
721+
lambda instance=default: instance,
722+
threadsafe=threadsafe,
723+
)
705724
metadata = function.__inject_metadata__.set_owner(cls)
706725
return SimpleAwaitable(metadata.acall)
707726

@@ -711,29 +730,29 @@ def get_lazy_instance[T, Default](
711730
cls: InputType[T],
712731
default: Default,
713732
*,
714-
cache: bool = ...,
733+
threadsafe: bool = ...,
715734
) -> Invertible[T | Default]: ...
716735

717736
@overload
718737
def get_lazy_instance[T](
719738
self,
720739
cls: InputType[T],
721-
default: None = ...,
740+
default: T = ...,
722741
*,
723-
cache: bool = ...,
724-
) -> Invertible[T | None]: ...
742+
threadsafe: bool = ...,
743+
) -> Invertible[T]: ...
725744

726745
def get_lazy_instance[T, Default](
727746
self,
728747
cls: InputType[T],
729-
default: Default | None = None,
748+
default: Default = NotImplemented,
730749
*,
731-
cache: bool = False,
732-
) -> Invertible[T | Default | None]:
733-
if cache:
734-
return lazy(lambda: self.get_instance(cls, default))
735-
736-
function = self.make_injected_function(lambda instance=default: instance)
750+
threadsafe: bool = False,
751+
) -> Invertible[T | Default]:
752+
function = self.make_injected_function(
753+
lambda instance=default: instance,
754+
threadsafe=threadsafe,
755+
)
737756
metadata = function.__inject_metadata__.set_owner(cls)
738757
return SimpleInvertible(metadata.call)
739758

@@ -996,7 +1015,7 @@ class InjectMetadata[**P, T](Caller[P, T], EventListener):
9961015

9971016
def __init__(self, wrapped: Callable[P, T], /, threadsafe: bool) -> None:
9981017
self.__dependencies = Dependencies.empty()
999-
self.__lock = threading.RLock() if threadsafe else nullcontext()
1018+
self.__lock = get_lock(threadsafe)
10001019
self.__owner = None
10011020
self.__tasks = deque()
10021021
self.__wrapped = wrapped

0 commit comments

Comments
 (0)