Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 41 additions & 35 deletions injection/_core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
AsyncIterator,
Awaitable,
Callable,
Container,
Generator,
Iterable,
Iterator,
Expand All @@ -18,6 +19,7 @@
from enum import StrEnum
from functools import partial, partialmethod, singledispatchmethod, update_wrapper
from inspect import (
BoundArguments,
Signature,
isasyncgenfunction,
isclass,
Expand Down Expand Up @@ -739,28 +741,32 @@ def mod(name: str | None = None, /) -> Module:
class Dependencies:
lazy_mapping: Lazy[Mapping[str, Injectable[Any]]]

def __iter__(self) -> Iterator[tuple[str, Any]]:
for name, injectable in self.items():
def iter(self, exclude: Container[str]) -> Iterator[tuple[str, Any]]:
for name, injectable in self.items(exclude):
with suppress(SkipInjectable):
yield name, injectable.get_instance()

async def __aiter__(self) -> AsyncIterator[tuple[str, Any]]:
for name, injectable in self.items():
async def aiter(self, exclude: Container[str]) -> AsyncIterator[tuple[str, Any]]:
for name, injectable in self.items(exclude):
with suppress(SkipInjectable):
yield name, await injectable.aget_instance()

@property
def are_resolved(self) -> bool:
return self.lazy_mapping.is_set

async def aget_arguments(self) -> dict[str, Any]:
return {key: value async for key, value in self}
async def aget_arguments(self, *, exclude: Container[str]) -> dict[str, Any]:
return {key: value async for key, value in self.aiter(exclude)}

def get_arguments(self) -> dict[str, Any]:
return dict(self)
def get_arguments(self, *, exclude: Container[str]) -> dict[str, Any]:
return dict(self.iter(exclude))

def items(self) -> Iterator[tuple[str, Injectable[Any]]]:
return iter((~self.lazy_mapping).items())
def items(self, exclude: Container[str]) -> Iterator[tuple[str, Injectable[Any]]]:
return (
(name, injectable)
for name, injectable in (~self.lazy_mapping).items()
if name not in exclude
)

@classmethod
def from_iterable(cls, iterable: Iterable[tuple[str, Injectable[Any]]]) -> Self:
Expand Down Expand Up @@ -858,21 +864,21 @@ def signature(self) -> Signature:
def wrapped(self) -> Callable[P, T]:
return self.__wrapped

async def abind(
self,
args: Iterable[Any] = (),
kwargs: Mapping[str, Any] | None = None,
) -> Arguments:
additional_arguments = await self.__dependencies.aget_arguments()
return self.__bind(args, kwargs, additional_arguments)
async def abind(self, args: Iterable[Any], kwargs: Mapping[str, Any]) -> Arguments:
arguments = self.__get_arguments(args, kwargs)
dependencies = await self.__dependencies.aget_arguments(exclude=arguments)
if dependencies:
return self.__merge_arguments(arguments, dependencies)

def bind(
self,
args: Iterable[Any] = (),
kwargs: Mapping[str, Any] | None = None,
) -> Arguments:
additional_arguments = self.__dependencies.get_arguments()
return self.__bind(args, kwargs, additional_arguments)
return Arguments(args, kwargs)

def bind(self, args: Iterable[Any], kwargs: Mapping[str, Any]) -> Arguments:
arguments = self.__get_arguments(args, kwargs)
dependencies = self.__dependencies.get_arguments(exclude=arguments)
if dependencies:
return self.__merge_arguments(arguments, dependencies)

return Arguments(args, kwargs)

async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
with self.__lock:
Expand Down Expand Up @@ -921,20 +927,20 @@ def _(self, event: ModuleEvent, /) -> Iterator[None]:
yield
self.update(event.module)

def __bind(
def __get_arguments(
self,
args: Iterable[Any],
kwargs: Mapping[str, Any] | None,
additional_arguments: dict[str, Any] | None,
) -> Arguments:
if kwargs is None:
kwargs = {}

if not additional_arguments:
return Arguments(args, kwargs)

kwargs: Mapping[str, Any],
) -> dict[str, Any]:
bound = self.signature.bind_partial(*args, **kwargs)
bound.arguments = bound.arguments | additional_arguments | bound.arguments
return bound.arguments

def __merge_arguments(
self,
arguments: dict[str, Any],
additional_arguments: dict[str, Any],
) -> Arguments:
bound = BoundArguments(self.signature, additional_arguments | arguments) # type: ignore[arg-type]
return Arguments(bound.args, bound.kwargs)

def __run_tasks(self) -> None:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,18 @@ def _method(this=..., _: SomeInjectable = ...):
@injectable
class A:
method = _method

def test_inject_with_passing_argument_do_not_lock_module(self, module):
assert not module.is_locked

@module.singleton
class A: ...

@module.inject
def function(a: A): ...

function(A())
assert not module.is_locked

function()
assert module.is_locked
Loading