2222import threading
2323import types
2424from abc import ABCMeta , abstractmethod
25- from collections import namedtuple
25+ from dataclasses import dataclass
2626from typing import (
27+ TYPE_CHECKING ,
2728 Any ,
2829 Callable ,
29- cast ,
3030 Dict ,
31+ Generator ,
3132 Generic ,
3233 Iterable ,
3334 List ,
3435 Optional ,
35- overload ,
3636 Set ,
3737 Tuple ,
3838 Type ,
3939 TypeVar ,
40- TYPE_CHECKING ,
4140 Union ,
41+ cast ,
42+ get_args ,
43+ overload ,
4244)
4345
4446try :
5153# canonical. Since this typing_extensions import is only for mypy it'll work even without
5254# typing_extensions actually installed so all's good.
5355if TYPE_CHECKING :
54- from typing_extensions import _AnnotatedAlias , Annotated , get_type_hints
56+ from typing_extensions import Annotated , _AnnotatedAlias , get_type_hints
5557else :
5658 # Ignoring errors here as typing_extensions stub doesn't know about those things yet
5759 try :
58- from typing import _AnnotatedAlias , Annotated , get_type_hints
60+ from typing import Annotated , _AnnotatedAlias , get_type_hints
5961 except ImportError :
60- from typing_extensions import _AnnotatedAlias , Annotated , get_type_hints
62+ from typing_extensions import Annotated , _AnnotatedAlias , get_type_hints
6163
6264
6365__author__ = 'Alec Thomas <alec@swapoff.org>'
64- __version__ = '0.22 .0'
66+ __version__ = '0.23 .0'
6567__version_tag__ = ''
6668
6769log = logging .getLogger ('injector' )
@@ -244,6 +246,10 @@ class UnknownArgument(Error):
244246 """Tried to mark an unknown argument as noninjectable."""
245247
246248
249+ class InvalidInterface (Error ):
250+ """Cannot bind to the specified interface."""
251+
252+
247253class Provider (Generic [T ]):
248254 """Provides class instances."""
249255
@@ -335,40 +341,110 @@ def __repr__(self) -> str:
335341
336342
337343@private
338- class ListOfProviders (Provider , Generic [T ]):
344+ class MultiBinder (Provider , Generic [T ]):
339345 """Provide a list of instances via other Providers."""
340346
341- _providers : List [ Provider [ T ]]
347+ __metaclass__ = ABCMeta
342348
343- def __init__ (self ) -> None :
344- self ._providers = []
349+ _multi_bindings : List ['Binding' ]
345350
346- def append (self , provider : Provider [T ]) -> None :
347- self ._providers .append (provider )
351+ def __init__ (self , parent : 'Binder' ) -> None :
352+ self ._multi_bindings = []
353+ self ._binder = Binder (parent .injector , auto_bind = False , parent = parent )
354+
355+ @abstractmethod
356+ def multibind (
357+ self , interface : type , to : Any , scope : Union ['ScopeDecorator' , Type ['Scope' ], None ]
358+ ) -> None :
359+ raise NotImplementedError
360+
361+ def append (self , provider : Provider [T ], scope : Type ['Scope' ]) -> None :
362+ # HACK: generate a pseudo-type for this element in the list.
363+ # This is needed for scopes to work properly. Some, like the Singleton scope,
364+ # key instances by type, so we need one that is unique to this binding.
365+ pseudo_type = type (f"multibind-type-{ id (provider )} " , (provider .__class__ ,), {})
366+ self ._multi_bindings .append (Binding (pseudo_type , provider , scope ))
367+
368+ def get_scoped_providers (self , injector : 'Injector' ) -> Generator [Provider [T ], None , None ]:
369+ for binding in self ._multi_bindings :
370+ scope_binding , _ = self ._binder .get_binding (binding .scope )
371+ scope_instance : Scope = scope_binding .provider .get (injector )
372+ provider_instance = scope_instance .get (binding .interface , binding .provider )
373+ yield provider_instance
348374
349375 def __repr__ (self ) -> str :
350- return '%s(%r)' % (type (self ).__name__ , self ._providers )
376+ return '%s(%r)' % (type (self ).__name__ , self ._multi_bindings )
351377
352378
353- class MultiBindProvider (ListOfProviders [List [T ]]):
379+ class MultiBindProvider (MultiBinder [List [T ]]):
354380 """Used by :meth:`Binder.multibind` to flatten results of providers that
355381 return sequences."""
356382
383+ def multibind (
384+ self , interface : type , to : Any , scope : Union ['ScopeDecorator' , Type ['Scope' ], None ]
385+ ) -> None :
386+ try :
387+ element_type = get_args (_punch_through_alias (interface ))[0 ]
388+ except IndexError :
389+ raise InvalidInterface (f"Use typing.List[T] or list[T] to specify the element type of the list" )
390+ if isinstance (to , list ):
391+ for element in to :
392+ element_binding = self ._binder .create_binding (element_type , element , scope )
393+ self .append (element_binding .provider , element_binding .scope )
394+ else :
395+ element_binding = self ._binder .create_binding (interface , to , scope )
396+ self .append (element_binding .provider , element_binding .scope )
397+
357398 def get (self , injector : 'Injector' ) -> List [T ]:
358- return [i for provider in self ._providers for i in provider .get (injector )]
399+ result : List [T ] = []
400+ for provider in self .get_scoped_providers (injector ):
401+ instances : List [T ] = _ensure_iterable (provider .get (injector ))
402+ result .extend (instances )
403+ return result
359404
360405
361- class MapBindProvider (ListOfProviders [Dict [str , T ]]):
406+ class MapBindProvider (MultiBinder [Dict [str , T ]]):
362407 """A provider for map bindings."""
363408
409+ def multibind (
410+ self , interface : type , to : Any , scope : Union ['ScopeDecorator' , Type ['Scope' ], None ]
411+ ) -> None :
412+ try :
413+ value_type = get_args (_punch_through_alias (interface ))[1 ]
414+ except IndexError :
415+ raise InvalidInterface (
416+ f"Use typing.Dict[K, V] or dict[K, V] to specify the value type of the dict"
417+ )
418+ if isinstance (to , dict ):
419+ for key , value in to .items ():
420+ element_binding = self ._binder .create_binding (value_type , value , scope )
421+ self .append (KeyValueProvider (key , element_binding .provider ), element_binding .scope )
422+ else :
423+ element_binding = self ._binder .create_binding (interface , to , scope )
424+ self .append (element_binding .provider , element_binding .scope )
425+
364426 def get (self , injector : 'Injector' ) -> Dict [str , T ]:
365427 map : Dict [str , T ] = {}
366- for provider in self ._providers :
428+ for provider in self .get_scoped_providers ( injector ) :
367429 map .update (provider .get (injector ))
368430 return map
369431
370432
371- _BindingBase = namedtuple ('_BindingBase' , 'interface provider scope' )
433+ @private
434+ class KeyValueProvider (Provider [Dict [str , T ]]):
435+ def __init__ (self , key : str , inner_provider : Provider [T ]) -> None :
436+ self ._key = key
437+ self ._provider = inner_provider
438+
439+ def get (self , injector : 'Injector' ) -> Dict [str , T ]:
440+ return {self ._key : self ._provider .get (injector )}
441+
442+
443+ @dataclass
444+ class _BindingBase :
445+ interface : type
446+ provider : Provider
447+ scope : Type ['Scope' ]
372448
373449
374450@private
@@ -468,7 +544,7 @@ def bind(
468544 def multibind (
469545 self ,
470546 interface : Type [List [T ]],
471- to : Union [List [T ] , Callable [..., List [T ]], Provider [List [T ]]],
547+ to : Union [List [Union [ T , Type [ T ]]] , Callable [..., List [T ]], Provider [List [T ]], Type [ T ]],
472548 scope : Union [Type ['Scope' ], 'ScopeDecorator' , None ] = None ,
473549 ) -> None : # pragma: no cover
474550 pass
@@ -477,7 +553,7 @@ def multibind(
477553 def multibind (
478554 self ,
479555 interface : Type [Dict [K , V ]],
480- to : Union [Dict [K , V ], Callable [..., Dict [K , V ]], Provider [Dict [K , V ]]],
556+ to : Union [Dict [K , Union [ V , Type [ V ]] ], Callable [..., Dict [K , V ]], Provider [Dict [K , V ]]],
481557 scope : Union [Type ['Scope' ], 'ScopeDecorator' , None ] = None ,
482558 ) -> None : # pragma: no cover
483559 pass
@@ -489,42 +565,52 @@ def multibind(
489565
490566 A multi-binding contributes values to a list or to a dictionary. For example::
491567
492- binder.multibind(List[str], to=['some', 'strings'])
493- binder.multibind(List[str], to=['other', 'strings'])
494- injector.get(List[str]) # ['some', 'strings', 'other', 'strings']
568+ binder.multibind(list[Interface], to=A)
569+ binder.multibind(list[Interface], to=[B, C()])
570+ injector.get(list[Interface])
571+ # [<A object at 0x1000>, <B object at 0x2000>, <C object at 0x3000>]
495572
496- binder.multibind(Dict[str, int], to={'key': 11})
497- binder.multibind(Dict[str, int], to={'other_key': 33})
498- injector.get(Dict[str, int]) # {'key': 11, 'other_key': 33}
573+ binder.multibind(dict[str, Interface], to={'key': A})
574+ binder.multibind(dict[str, Interface], to={'other_key': B})
575+ injector.get(dict[str, Interface])
576+ # {'key': <A object at 0x1000>, 'other_key': <B object at 0x2000>}
499577
500578 .. versionchanged:: 0.17.0
501579 Added support for using `typing.Dict` and `typing.List` instances as interfaces.
502580 Deprecated support for `MappingKey`, `SequenceKey` and single-item lists and
503581 dictionaries as interfaces.
504582
505- :param interface: typing.Dict or typing.List instance to bind to.
506- :param to: Instance, class to bind to, or an explicit :class:`Provider`
507- subclass. Must provide a list or a dictionary, depending on the interface.
583+ :param interface: A generic list[T] or dict[str, T] type to bind to.
584+
585+ :param to: A list/dict to bind to, where the values are either instances or classes implementing T.
586+ Can also be an explicit :class:`Provider` or a callable that returns a list/dict.
587+ For lists, this can also be a class implementing T (e.g. multibind(list[T], to=A))
588+
508589 :param scope: Optional Scope in which to bind.
509590 """
591+ multi_binder = self ._get_multi_binder (interface )
592+ multi_binder .multibind (interface , to , scope )
593+
594+ def _get_multi_binder (self , interface : type ) -> MultiBinder :
595+ multi_binder : MultiBinder
510596 if interface not in self ._bindings :
511- provider : ListOfProviders
512597 if (
513598 isinstance (interface , dict )
514599 or isinstance (interface , type )
515600 and issubclass (interface , dict )
516601 or _get_origin (_punch_through_alias (interface )) is dict
517602 ):
518- provider = MapBindProvider ()
603+ multi_binder = MapBindProvider (self )
519604 else :
520- provider = MultiBindProvider ()
521- binding = self .create_binding (interface , provider , scope )
605+ multi_binder = MultiBindProvider (self )
606+ binding = self .create_binding (interface , multi_binder )
522607 self ._bindings [interface ] = binding
523608 else :
524609 binding = self ._bindings [interface ]
525- provider = binding .provider
526- assert isinstance (provider , ListOfProviders )
527- provider .append (self .provider_for (interface , to ))
610+ assert isinstance (binding .provider , MultiBinder )
611+ multi_binder = binding .provider
612+
613+ return multi_binder
528614
529615 def install (self , module : _InstallableModuleType ) -> None :
530616 """Install a module into this binder.
@@ -567,10 +653,10 @@ def create_binding(
567653 self , interface : type , to : Any = None , scope : Union ['ScopeDecorator' , Type ['Scope' ], None ] = None
568654 ) -> Binding :
569655 provider = self .provider_for (interface , to )
570- scope = scope or getattr (to or interface , '__scope__' , NoScope )
656+ scope = scope or getattr (to or interface , '__scope__' , None )
571657 if isinstance (scope , ScopeDecorator ):
572658 scope = scope .scope
573- return Binding (interface , provider , scope )
659+ return Binding (interface , provider , scope or NoScope )
574660
575661 def provider_for (self , interface : Any , to : Any = None ) -> Provider :
576662 base_type = _punch_through_alias (interface )
@@ -652,7 +738,7 @@ def get_binding(self, interface: type) -> Tuple[Binding, 'Binder']:
652738 # The special interface is added here so that requesting a special
653739 # interface with auto_bind disabled works
654740 if self ._auto_bind or self ._is_special_interface (interface ):
655- binding = ImplicitBinding (* self .create_binding (interface ))
741+ binding = ImplicitBinding (** self .create_binding (interface ). __dict__ )
656742 self ._bindings [interface ] = binding
657743 return binding , self
658744
@@ -696,6 +782,12 @@ def _is_specialization(cls: type, generic_class: Any) -> bool:
696782 return origin is generic_class or issubclass (origin , generic_class )
697783
698784
785+ def _ensure_iterable (item_or_list : Union [T , List [T ]]) -> List [T ]:
786+ if isinstance (item_or_list , list ):
787+ return item_or_list
788+ return [item_or_list ]
789+
790+
699791def _punch_through_alias (type_ : Any ) -> type :
700792 if (
701793 sys .version_info < (3 , 10 )
@@ -767,7 +859,7 @@ def __repr__(self) -> str:
767859class NoScope (Scope ):
768860 """An unscoped provider."""
769861
770- def get (self , unused_key : Type [T ], provider : Provider [T ]) -> Provider [T ]:
862+ def get (self , key : Type [T ], provider : Provider [T ]) -> Provider [T ]:
771863 return provider
772864
773865
@@ -1339,7 +1431,7 @@ def provide_strs_also(self) -> List[str]:
13391431def _mark_provider_function (function : Callable , * , allow_multi : bool ) -> None :
13401432 scope_ = getattr (function , '__scope__' , None )
13411433 try :
1342- annotations = get_type_hints (function )
1434+ annotations = get_type_hints (function , include_extras = True )
13431435 except NameError :
13441436 return_type = '__deferred__'
13451437 else :
0 commit comments