11from __future__ import annotations
22
3- import threading
43from abc import ABC , abstractmethod
54from collections import OrderedDict , deque
65from collections .abc import (
1413 Iterator ,
1514 Mapping ,
1615)
17- from contextlib import asynccontextmanager , contextmanager , nullcontext , suppress
16+ from contextlib import asynccontextmanager , contextmanager , suppress
1817from dataclasses import dataclass , field
1918from enum import StrEnum
2019from functools import partial , partialmethod , singledispatchmethod , update_wrapper
5150from injection ._core .common .event import Event , EventChannel , EventListener
5251from injection ._core .common .invertible import Invertible , SimpleInvertible
5352from injection ._core .common .key import new_short_key
54- from injection ._core .common .lazy import Lazy , alazy , lazy
53+ from injection ._core .common .lazy import Lazy , lazy
54+ from injection ._core .common .threading import get_lock
5555from injection ._core .common .type import (
5656 InputType ,
5757 TypeInfo ,
@@ -617,35 +617,48 @@ def make_async_factory[T](
617617 )
618618 return factory .__inject_metadata__ .acall
619619
620- async def afind_instance [T ](self , cls : InputType [T ]) -> T :
621- injectable = self [cls ]
622- return await injectable .aget_instance ()
620+ async def afind_instance [T ](
621+ self ,
622+ cls : InputType [T ],
623+ * ,
624+ threadsafe : bool = False ,
625+ ) -> T :
626+ with get_lock (threadsafe ):
627+ injectable = self [cls ]
628+ return await injectable .aget_instance ()
623629
624- def find_instance [T ](self , cls : InputType [T ]) -> T :
625- injectable = self [cls ]
626- return injectable .get_instance ()
630+ def find_instance [T ](self , cls : InputType [T ], * , threadsafe : bool = False ) -> T :
631+ with get_lock (threadsafe ):
632+ injectable = self [cls ]
633+ return injectable .get_instance ()
627634
628635 @overload
629636 async def aget_instance [T , Default ](
630637 self ,
631638 cls : InputType [T ],
632639 default : Default ,
640+ * ,
641+ threadsafe : bool = ...,
633642 ) -> T | Default : ...
634643
635644 @overload
636645 async def aget_instance [T ](
637646 self ,
638647 cls : InputType [T ],
639- default : None = ...,
640- ) -> T | None : ...
648+ default : T = ...,
649+ * ,
650+ threadsafe : bool = ...,
651+ ) -> T : ...
641652
642653 async def aget_instance [T , Default ](
643654 self ,
644655 cls : InputType [T ],
645- default : Default | None = None ,
646- ) -> T | Default | None :
656+ default : Default = NotImplemented ,
657+ * ,
658+ threadsafe : bool = False ,
659+ ) -> T | Default :
647660 try :
648- return await self .afind_instance (cls )
661+ return await self .afind_instance (cls , threadsafe = threadsafe )
649662 except (KeyError , SkipInjectable ):
650663 return default
651664
@@ -654,22 +667,28 @@ def get_instance[T, Default](
654667 self ,
655668 cls : InputType [T ],
656669 default : Default ,
670+ * ,
671+ threadsafe : bool = ...,
657672 ) -> T | Default : ...
658673
659674 @overload
660675 def get_instance [T ](
661676 self ,
662677 cls : InputType [T ],
663- default : None = ...,
664- ) -> T | None : ...
678+ default : T = ...,
679+ * ,
680+ threadsafe : bool = ...,
681+ ) -> T : ...
665682
666683 def get_instance [T , Default ](
667684 self ,
668685 cls : InputType [T ],
669- default : Default | None = None ,
670- ) -> T | Default | None :
686+ default : Default = NotImplemented ,
687+ * ,
688+ threadsafe : bool = False ,
689+ ) -> T | Default :
671690 try :
672- return self .find_instance (cls )
691+ return self .find_instance (cls , threadsafe = threadsafe )
673692 except (KeyError , SkipInjectable ):
674693 return default
675694
@@ -679,29 +698,29 @@ def aget_lazy_instance[T, Default](
679698 cls : InputType [T ],
680699 default : Default ,
681700 * ,
682- cache : bool = ...,
701+ threadsafe : bool = ...,
683702 ) -> Awaitable [T | Default ]: ...
684703
685704 @overload
686705 def aget_lazy_instance [T ](
687706 self ,
688707 cls : InputType [T ],
689- default : None = ...,
708+ default : T = ...,
690709 * ,
691- cache : bool = ...,
692- ) -> Awaitable [T | None ]: ...
710+ threadsafe : bool = ...,
711+ ) -> Awaitable [T ]: ...
693712
694713 def aget_lazy_instance [T , Default ](
695714 self ,
696715 cls : InputType [T ],
697- default : Default | None = None ,
716+ default : Default = NotImplemented ,
698717 * ,
699- cache : bool = False ,
700- ) -> Awaitable [T | Default | None ]:
701- if cache :
702- return alazy ( lambda : self . aget_instance ( cls , default ))
703-
704- function = self . make_injected_function ( lambda instance = default : instance )
718+ threadsafe : bool = False ,
719+ ) -> Awaitable [T | Default ]:
720+ function = self . make_injected_function (
721+ lambda instance = default : instance ,
722+ threadsafe = threadsafe ,
723+ )
705724 metadata = function .__inject_metadata__ .set_owner (cls )
706725 return SimpleAwaitable (metadata .acall )
707726
@@ -711,29 +730,29 @@ def get_lazy_instance[T, Default](
711730 cls : InputType [T ],
712731 default : Default ,
713732 * ,
714- cache : bool = ...,
733+ threadsafe : bool = ...,
715734 ) -> Invertible [T | Default ]: ...
716735
717736 @overload
718737 def get_lazy_instance [T ](
719738 self ,
720739 cls : InputType [T ],
721- default : None = ...,
740+ default : T = ...,
722741 * ,
723- cache : bool = ...,
724- ) -> Invertible [T | None ]: ...
742+ threadsafe : bool = ...,
743+ ) -> Invertible [T ]: ...
725744
726745 def get_lazy_instance [T , Default ](
727746 self ,
728747 cls : InputType [T ],
729- default : Default | None = None ,
748+ default : Default = NotImplemented ,
730749 * ,
731- cache : bool = False ,
732- ) -> Invertible [T | Default | None ]:
733- if cache :
734- return lazy ( lambda : self . get_instance ( cls , default ))
735-
736- function = self . make_injected_function ( lambda instance = default : instance )
750+ threadsafe : bool = False ,
751+ ) -> Invertible [T | Default ]:
752+ function = self . make_injected_function (
753+ lambda instance = default : instance ,
754+ threadsafe = threadsafe ,
755+ )
737756 metadata = function .__inject_metadata__ .set_owner (cls )
738757 return SimpleInvertible (metadata .call )
739758
@@ -996,7 +1015,7 @@ class InjectMetadata[**P, T](Caller[P, T], EventListener):
9961015
9971016 def __init__ (self , wrapped : Callable [P , T ], / , threadsafe : bool ) -> None :
9981017 self .__dependencies = Dependencies .empty ()
999- self .__lock = threading . RLock () if threadsafe else nullcontext ( )
1018+ self .__lock = get_lock ( threadsafe )
10001019 self .__owner = None
10011020 self .__tasks = deque ()
10021021 self .__wrapped = wrapped
0 commit comments