Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions injection/_core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,11 +816,11 @@ def load_profile(self, *names: str) -> ContextManager[Self]:
self.unlock().init_modules(*modules)

@contextmanager
def cleaner() -> Iterator[Self]:
def unload() -> Iterator[Self]:
yield self
self.unlock().init_modules()

return cleaner()
return unload()

async def all_ready(self) -> None:
for broker in self.__brokers:
Expand Down
8 changes: 8 additions & 0 deletions injection/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ def load_modules(
) -> Self:
return self.setup(lambda: loader.load(*packages))

def load_profile(self, /, *names: str) -> Self:
@contextmanager
def decorator(module: Module) -> Iterator[None]:
with module.load_profile(*names):
yield

return self.decorate(decorator(self.module))

def setup(self, function: Callable[..., Any], /) -> Self:
@contextmanager
def decorator() -> Iterator[Any]:
Expand Down
14 changes: 13 additions & 1 deletion tests/entrypoint/test_entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Iterator
from contextlib import contextmanager

from injection import injectable
from injection import injectable, mod
from injection.entrypoint import Entrypoint


Expand Down Expand Up @@ -41,6 +41,18 @@ def function(service: Service) -> bool:
entrypoint = Entrypoint(function).inject()
assert entrypoint()

def test_load_profile_with_success_return_entrypoint(self):
profile_name = "test"

@mod(profile_name).injectable
class Service: ...

def function(service: Service) -> bool:
return isinstance(service, Service)

entrypoint = Entrypoint(function).inject().load_profile(profile_name)
assert entrypoint()

def test_setup_with_success_return_entrypoint(self):
count = 0

Expand Down