Skip to content

Commit 53bacb3

Browse files
authored
feat: ✨️ Add threadsafe parameter to @inject
1 parent 84ee0ce commit 53bacb3

5 files changed

Lines changed: 114 additions & 52 deletions

File tree

injection/__init__.pyi

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,19 @@ class Module:
7474
def __contains__(self, cls: _InputType[Any], /) -> bool: ...
7575
@property
7676
def is_locked(self) -> bool: ...
77-
def inject[**P, T](self, wrapped: Callable[P, T] = ..., /) -> Any:
77+
def inject[**P, T](
78+
self,
79+
wrapped: Callable[P, T] = ...,
80+
/,
81+
*,
82+
threadsafe: bool = ...,
83+
) -> Any:
7884
"""
7985
Decorator applicable to a class or function. Inject function dependencies using
8086
parameter type annotations. If applied to a class, the dependencies resolved
8187
will be those of the `__init__` method.
88+
89+
With `threadsafe=True`, the injection logic is wrapped in a `threading.Lock`.
8290
"""
8391

8492
def injectable[**P, T](
@@ -166,6 +174,7 @@ class Module:
166174
self,
167175
wrapped: Callable[P, T],
168176
/,
177+
threadsafe: bool = ...,
169178
) -> Callable[P, T]: ...
170179
async def afind_instance[T](self, cls: _InputType[T]) -> T: ...
171180
def find_instance[T](self, cls: _InputType[T]) -> T:
@@ -295,6 +304,8 @@ class Module:
295304
Function to unlock the module by deleting cached instances of singletons.
296305
"""
297306

307+
@contextmanager
308+
def load_profile(self, *names: str) -> Iterator[None]: ...
298309
async def all_ready(self) -> None: ...
299310
def add_logger(self, logger: Logger) -> Self: ...
300311
@classmethod

injection/_core/module.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
Iterator,
1212
Mapping,
1313
)
14-
from contextlib import asynccontextmanager, contextmanager, suppress
14+
from contextlib import asynccontextmanager, contextmanager, nullcontext, suppress
1515
from dataclasses import dataclass, field
1616
from enum import StrEnum
1717
from functools import partialmethod, singledispatchmethod, update_wrapper
@@ -25,6 +25,7 @@
2525
)
2626
from inspect import signature as inspect_signature
2727
from logging import Logger, getLogger
28+
from threading import Lock
2829
from types import MethodType
2930
from typing import (
3031
Any,
@@ -542,13 +543,19 @@ def set_constant[T](
542543
)
543544
return self
544545

545-
def inject[**P, T](self, wrapped: Callable[P, T] | None = None, /) -> Any:
546+
def inject[**P, T](
547+
self,
548+
wrapped: Callable[P, T] | None = None,
549+
/,
550+
*,
551+
threadsafe: bool = False,
552+
) -> Any:
546553
def decorator(wp: Callable[P, T]) -> Callable[P, T]:
547554
if isclass(wp):
548-
wp.__init__ = self.inject(wp.__init__)
555+
wp.__init__ = self.inject(wp.__init__, threadsafe=threadsafe)
549556
return wp
550557

551-
return self.make_injected_function(wp)
558+
return self.make_injected_function(wp, threadsafe)
552559

553560
return decorator(wrapped) if wrapped else decorator
554561

@@ -557,17 +564,19 @@ def make_injected_function[**P, T](
557564
self,
558565
wrapped: Callable[P, T],
559566
/,
567+
threadsafe: bool = ...,
560568
) -> SyncInjectedFunction[P, T]: ...
561569

562570
@overload
563571
def make_injected_function[**P, T](
564572
self,
565573
wrapped: Callable[P, Awaitable[T]],
566574
/,
575+
threadsafe: bool = ...,
567576
) -> AsyncInjectedFunction[P, T]: ...
568577

569-
def make_injected_function(self, wrapped, /): # type: ignore[no-untyped-def]
570-
metadata = InjectMetadata(wrapped)
578+
def make_injected_function(self, wrapped, /, threadsafe=False): # type: ignore[no-untyped-def]
579+
metadata = InjectMetadata(wrapped, threadsafe)
571580

572581
@metadata.task
573582
def listen() -> None:
@@ -753,6 +762,23 @@ def unlock(self) -> Self:
753762

754763
return self
755764

765+
def load_profile(self, *names: str) -> ContextManager[None]:
766+
modules = tuple(self.from_name(name) for name in names)
767+
768+
for module in modules:
769+
module.unlock()
770+
771+
self.unlock().init_modules(*modules)
772+
773+
del module, modules
774+
775+
@contextmanager
776+
def cleaner() -> Iterator[None]:
777+
yield
778+
self.unlock().init_modules()
779+
780+
return cleaner()
781+
756782
async def all_ready(self) -> None:
757783
for broker in self.__brokers:
758784
await broker.all_ready()
@@ -913,20 +939,23 @@ class Arguments(NamedTuple):
913939
class InjectMetadata[**P, T](Caller[P, T], EventListener):
914940
__slots__ = (
915941
"__dependencies",
942+
"__lock",
916943
"__owner",
917944
"__signature",
918945
"__tasks",
919946
"__wrapped",
920947
)
921948

922949
__dependencies: Dependencies
950+
__lock: ContextManager[Any]
923951
__owner: type | None
924952
__signature: Signature
925953
__tasks: deque[Callable[..., Any]]
926954
__wrapped: Callable[P, T]
927955

928-
def __init__(self, wrapped: Callable[P, T], /) -> None:
956+
def __init__(self, wrapped: Callable[P, T], /, threadsafe: bool) -> None:
929957
self.__dependencies = Dependencies.empty()
958+
self.__lock = Lock() if threadsafe else nullcontext()
930959
self.__owner = None
931960
self.__tasks = deque()
932961
self.__wrapped = wrapped
@@ -961,13 +990,17 @@ def bind(
961990
return self.__bind(args, kwargs, additional_arguments)
962991

963992
async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
964-
self.__run_tasks()
965-
arguments = await self.abind(args, kwargs)
993+
with self.__lock:
994+
self.__run_tasks()
995+
arguments = await self.abind(args, kwargs)
996+
966997
return self.wrapped(*arguments.args, **arguments.kwargs)
967998

968999
def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
969-
self.__run_tasks()
970-
arguments = self.bind(args, kwargs)
1000+
with self.__lock:
1001+
self.__run_tasks()
1002+
arguments = self.bind(args, kwargs)
1003+
9711004
return self.wrapped(*arguments.args, **arguments.kwargs)
9721005

9731006
def set_owner(self, owner: type) -> Self:

injection/_core/scope.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,21 @@ class _ScopeState:
3434
default_factory=lambda: ContextVar(f"scope@{new_short_key()}"),
3535
init=False,
3636
)
37-
__references: set[Scope] = field(
38-
default_factory=set,
37+
__default: Scope | None = field(
38+
default=None,
3939
init=False,
4040
)
41-
__shared_value: Scope | None = field(
42-
default=None,
41+
__references: set[Scope] = field(
42+
default_factory=set,
4343
init=False,
4444
)
4545

4646
@property
4747
def active_scopes(self) -> Iterator[Scope]:
4848
yield from self.__references
4949

50-
if shared_value := self.__shared_value:
51-
yield shared_value
50+
if default := self.__default:
51+
yield default
5252

5353
@contextmanager
5454
def bind_contextual_scope(self, scope: Scope) -> Iterator[None]:
@@ -69,15 +69,15 @@ def bind_shared_scope(self, scope: Scope) -> Iterator[None]:
6969
"are defined on the same name."
7070
)
7171

72-
self.__shared_value = scope
72+
self.__default = scope
7373

7474
try:
7575
yield
7676
finally:
77-
self.__shared_value = None
77+
self.__default = None
7878

7979
def get_scope(self) -> Scope | None:
80-
return self.__context_var.get(self.__shared_value)
80+
return self.__context_var.get(self.__default)
8181

8282

8383
__SCOPES: Final[defaultdict[str, _ScopeState]] = defaultdict(_ScopeState)
@@ -125,11 +125,8 @@ def _bind_scope(name: str, scope: Scope, shared: bool) -> Iterator[None]:
125125
state.bind_shared_scope(scope) if shared else state.bind_contextual_scope(scope)
126126
)
127127

128-
try:
129-
with strategy:
130-
yield
131-
finally:
132-
scope.cache.clear()
128+
with strategy:
129+
yield
133130

134131

135132
@runtime_checkable

injection/utils.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from collections.abc import Callable, Iterable, Iterator
2-
from contextlib import contextmanager
1+
from collections.abc import Callable, Collection, Iterator
32
from importlib import import_module
43
from importlib.util import find_spec
54
from pkgutil import walk_packages
@@ -18,26 +17,12 @@ def load_profile(*names: str) -> ContextManager[None]:
1817
A profile name is equivalent to an injection module name.
1918
"""
2019

21-
modules = tuple(mod(module_name) for module_name in names)
22-
23-
for module in modules:
24-
module.unlock()
25-
26-
target = mod().unlock().init_modules(*modules)
27-
28-
del module, modules
29-
30-
@contextmanager
31-
def cleaner() -> Iterator[None]:
32-
yield
33-
target.unlock().init_modules()
34-
35-
return cleaner()
20+
return mod().load_profile(*names)
3621

3722

3823
def load_modules_with_keywords(
3924
*packages: PythonModule | str,
40-
keywords: Iterable[str] | None = None,
25+
keywords: Collection[str] | None = None,
4126
) -> dict[str, PythonModule]:
4227
"""
4328
Function to import modules from a Python package if one of the keywords is contained in the Python script.
@@ -54,16 +39,16 @@ def load_modules_with_keywords(
5439
f"import {injection_package_name}",
5540
)
5641

57-
b_keywords = tuple(keyword.encode() for keyword in keywords)
58-
5942
def predicate(module_name: str) -> bool:
60-
if (spec := find_spec(module_name)) and (module_path := spec.origin):
61-
with open(module_path, "rb") as script:
62-
for line in script:
63-
line = b" ".join(line.split(b" ")).strip()
43+
spec = find_spec(module_name)
44+
45+
if spec and (module_path := spec.origin):
46+
with open(module_path, "r") as file:
47+
python_script = file.read()
6448

65-
if line and any(keyword in line for keyword in b_keywords):
66-
return True
49+
return bool(python_script) and any(
50+
keyword in python_script for keyword in keywords
51+
)
6752

6853
return False
6954

tests/utils/test_load_profile.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from injection import find_instance, injectable, mod
2+
from injection.utils import load_profile
3+
4+
5+
class TestLoadProfile:
6+
def test_load_profile_with_success(self):
7+
profile_name = "test"
8+
9+
@injectable
10+
class A: ...
11+
12+
@mod(profile_name).injectable(on=A)
13+
class B(A): ...
14+
15+
assert type(find_instance(A)) is A
16+
load_profile(profile_name)
17+
assert type(find_instance(A)) is B
18+
19+
# Cleaning
20+
mod().init_modules()
21+
22+
def test_load_profile_with_context_manager(self):
23+
profile_name = "test"
24+
25+
@injectable
26+
class A: ...
27+
28+
@mod(profile_name).injectable(on=A)
29+
class B(A): ...
30+
31+
assert type(find_instance(A)) is A
32+
33+
with load_profile(profile_name):
34+
assert type(find_instance(A)) is B
35+
36+
assert type(find_instance(A)) is A

0 commit comments

Comments
 (0)