Skip to content

Commit 81e22fb

Browse files
sobolevnRobertoPrevato
authored andcommitted
Do not ignore _globalns that is set via inject()
1 parent ba7ab21 commit 81e22fb

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

rodi/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ def _get_obj_locals(obj) -> dict[str, Any] | None:
8484
return getattr(obj, "_locals", None)
8585

8686

87+
def _get_obj_globals(obj) -> dict[str, Any]:
88+
return getattr(obj, "_globals", {})
89+
90+
8791
def class_name(input_type):
8892
if input_type in {list, set} and str( # noqa: E721
8993
type(input_type) == "<class 'types.GenericAlias'>"
@@ -568,9 +572,11 @@ def _resolve_by_init_method(self, context: ResolutionContext):
568572
for key, value in sig.parameters.items()
569573
}
570574

575+
globalns = dict(vars(sys.modules[self.concrete_type.__module__]))
576+
globalns.update(_get_obj_globals(self.concrete_type))
571577
annotations = get_type_hints(
572578
self.concrete_type.__init__,
573-
vars(sys.modules[self.concrete_type.__module__]),
579+
globalns,
574580
_get_obj_locals(self.concrete_type),
575581
)
576582
for key, value in params.items():
@@ -646,9 +652,11 @@ def __call__(self, context: ResolutionContext):
646652
chain.append(concrete_type)
647653

648654
if self._has_default_init():
655+
globalns = dict(vars(sys.modules[concrete_type.__module__]))
656+
globalns.update(_get_obj_globals(concrete_type))
649657
annotations = get_type_hints(
650658
concrete_type,
651-
vars(sys.modules[concrete_type.__module__]),
659+
globalns,
652660
_get_obj_locals(concrete_type),
653661
)
654662

tests/test_services.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2760,3 +2760,51 @@ async def test_nested_scope_async_1():
27602760
nested_scope_async(),
27612761
nested_scope_async(),
27622762
)
2763+
2764+
2765+
# Tests for inject(globalsns=...) being honoured during type resolution (#60)
2766+
2767+
2768+
def test_inject_globalsns_honoured_for_annotation_resolution():
2769+
"""
2770+
When a class uses a forward reference in a class-level annotation and the
2771+
type is provided via inject(globalsns=...), it should be resolved correctly.
2772+
"""
2773+
2774+
class LocalDep:
2775+
pass
2776+
2777+
@inject(globalsns={"LocalDep": LocalDep})
2778+
class Service:
2779+
dep: "LocalDep"
2780+
2781+
container = Container()
2782+
container.add_transient(LocalDep)
2783+
container.add_transient(Service)
2784+
provider = container.build_provider()
2785+
2786+
instance = provider.get(Service)
2787+
assert isinstance(instance.dep, LocalDep)
2788+
2789+
2790+
def test_inject_globalsns_honoured_for_init_resolution():
2791+
"""
2792+
When a class uses a forward reference in __init__ and the type is provided
2793+
via inject(globalsns=...), it should be resolved correctly.
2794+
"""
2795+
2796+
class LocalDep:
2797+
pass
2798+
2799+
@inject(globalsns={"LocalDep": LocalDep})
2800+
class Service:
2801+
def __init__(self, dep: "LocalDep") -> None:
2802+
self.dep = dep
2803+
2804+
container = Container()
2805+
container.add_transient(LocalDep)
2806+
container.add_transient(Service)
2807+
provider = container.build_provider()
2808+
2809+
instance = provider.get(Service)
2810+
assert isinstance(instance.dep, LocalDep)

0 commit comments

Comments
 (0)