Skip to content

Commit 33aeeaf

Browse files
authored
fix: 🐛 @constant now works with async recipes
1 parent bce9bb1 commit 33aeeaf

3 files changed

Lines changed: 52 additions & 16 deletions

File tree

injection/_core/common/lazy.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,50 @@
1-
from collections.abc import Callable, Iterator
1+
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator
22
from functools import partial
33

4-
from injection._core.common.invertible import Invertible, SimpleInvertible
4+
from injection._core.common.invertible import Invertible
55

66

7-
def lazy[T](factory: Callable[..., T]) -> Invertible[T]:
7+
def lazy[T](factory: Callable[..., T]) -> Callable[[], T]:
88
def cache() -> Iterator[T]:
9-
nonlocal factory
109
value = factory()
11-
del factory
10+
while True:
11+
yield value
12+
13+
return partial(next, cache())
1214

15+
16+
def alazy[T](factory: Callable[..., Awaitable[T]]) -> Callable[[], Awaitable[T]]:
17+
async def cache() -> AsyncIterator[T]:
18+
value = await factory()
1319
while True:
1420
yield value
1521

16-
getter = partial(next, cache())
17-
return SimpleInvertible(getter)
22+
return partial(_anext, cache())
1823

1924

2025
class Lazy[T](Invertible[T]):
21-
__slots__ = ("__invertible", "__is_set")
26+
__slots__ = ("__get", "__is_set")
2227

23-
__invertible: Invertible[T]
28+
__get: Callable[[], T]
2429
__is_set: bool
2530

2631
def __init__(self, factory: Callable[..., T]) -> None:
2732
@lazy
28-
def invertible() -> T:
33+
def get() -> T:
2934
value = factory()
3035
self.__is_set = True
3136
return value
3237

33-
self.__invertible = invertible
38+
self.__get = get
3439
self.__is_set = False
3540

3641
def __invert__(self) -> T:
37-
return ~self.__invertible
42+
return self.__get()
3843

3944
@property
4045
def is_set(self) -> bool:
4146
return self.__is_set
47+
48+
49+
async def _anext[T](async_iterator: AsyncIterator[T]) -> T:
50+
return await anext(async_iterator)

injection/_core/module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from injection._core.common.event import Event, EventChannel, EventListener
5151
from injection._core.common.invertible import Invertible, SimpleInvertible
5252
from injection._core.common.key import new_short_key
53-
from injection._core.common.lazy import Lazy, lazy
53+
from injection._core.common.lazy import Lazy, alazy, lazy
5454
from injection._core.common.threading import get_lock
5555
from injection._core.common.type import (
5656
InputType,
@@ -512,9 +512,9 @@ def constant[**P, T](
512512
mode: Mode | ModeStr = Mode.get_default(),
513513
) -> Any:
514514
def decorator(wp: Recipe[P, T]) -> Recipe[P, T]:
515-
lazy_instance = lazy(wp)
515+
recipe: Recipe[[], T] = alazy(wp) if iscoroutinefunction(wp) else lazy(wp) # type: ignore[arg-type]
516516
self.injectable(
517-
lambda: ~lazy_instance,
517+
recipe,
518518
ignore_type_hint=True,
519519
inject=False,
520520
on=(wp, on),

tests/test_constant.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from injection import constant, get_instance
3+
from injection import aget_instance, constant, get_instance
44

55

66
class TestConstant:
@@ -12,6 +12,33 @@ class SomeInjectable: ...
1212
instance_2 = get_instance(SomeInjectable)
1313
assert instance_1 is instance_2 is not None
1414

15+
def test_constant_with_recipe(self):
16+
class SomeClass: ...
17+
18+
@constant
19+
def recipe() -> SomeClass:
20+
return SomeClass()
21+
22+
instance_1 = get_instance(SomeClass)
23+
instance_2 = get_instance(SomeClass)
24+
assert instance_1 is instance_2
25+
assert isinstance(instance_1, SomeClass)
26+
27+
async def test_constant_with_async_recipe(self):
28+
class SomeClass: ...
29+
30+
@constant
31+
async def recipe() -> SomeClass:
32+
return SomeClass()
33+
34+
with pytest.raises(RuntimeError):
35+
get_instance(SomeClass)
36+
37+
instance_1 = await aget_instance(SomeClass)
38+
instance_2 = await aget_instance(SomeClass)
39+
assert instance_1 is instance_2
40+
assert isinstance(instance_1, SomeClass)
41+
1542
def test_constant_with_on(self):
1643
class A: ...
1744

0 commit comments

Comments
 (0)