77
88from openfisca_core .indexed_enums import config , EnumArray
99
10+ if typing .TYPE_CHECKING :
11+ IndexedEnumArray = numpy .object_
12+
1013
1114class Enum (enum .Enum ):
1215 """
13- Enum based on `enum34 <https://pypi.python.org/pypi/enum34/>`_, whose items have an
14- index.
16+ Enum based on `enum34 <https://pypi.python.org/pypi/enum34/>`_, whose items
17+ have an index.
1518 """
1619
1720 # Tweak enums to add an index attribute to each enum item
1821 def __init__ (self , name : str ) -> None :
19- # When the enum item is initialized, self._member_names_ contains the names of
20- # the previously initialized items, so its length is the index of this item.
22+ # When the enum item is initialized, self._member_names_ contains the
23+ # names of the previously initialized items, so its length is the index
24+ # of this item.
2125 self .index = len (self ._member_names_ )
2226
2327 # Bypass the slow Enum.__eq__
2428 __eq__ = object .__eq__
2529
26- # In Python 3, __hash__ must be defined if __eq__ is defined to stay hashable.
30+ # In Python 3, __hash__ must be defined if __eq__ is defined to stay
31+ # hashable.
2732 __hash__ = object .__hash__
2833
2934 @classmethod
3035 def encode (
3136 cls ,
3237 array : typing .Union [
3338 EnumArray ,
34- numpy .ndarray [ int ] ,
35- numpy .ndarray [ str ] ,
36- numpy . ndarray [ Enum ] ,
39+ numpy .int_ ,
40+ numpy .float_ ,
41+ IndexedEnumArray ,
3742 ],
3843 ) -> EnumArray :
3944 """
40- Encode a string numpy array, an enum item numpy array, or an int numpy array
41- into an :any:`EnumArray`. See :any:`EnumArray.decode` for decoding.
45+ Encode a string numpy array, an enum item numpy array, or an int numpy
46+ array into an :any:`EnumArray`. See :any:`EnumArray.decode` for
47+ decoding.
4248
43- :param ndarray array: Array of string identifiers, or of enum items, to encode.
49+ :param ndarray array: Array of string identifiers, or of enum items, to
50+ encode.
4451
4552 :returns: An :any:`EnumArray` encoding the input array values.
4653 :rtype: :any:`EnumArray`
@@ -59,24 +66,31 @@ def encode(
5966 >>> encoded_array[0]
6067 2 # Encoded value
6168 """
62- if type (array ) is EnumArray :
69+ if isinstance (array , EnumArray ) :
6370 return array
6471
65- if array .dtype .kind in {'U' , 'S' }: # String array
72+ # String array
73+ if isinstance (array , numpy .ndarray ) and \
74+ array .dtype .kind in {'U' , 'S' }:
6675 array = numpy .select (
6776 [array == item .name for item in cls ],
6877 [item .index for item in cls ],
6978 ).astype (config .ENUM_ARRAY_DTYPE )
7079
71- elif array .dtype .kind == 'O' : # Enum items arrays
80+ # Enum items arrays
81+ elif isinstance (array , numpy .ndarray ) and \
82+ array .dtype .kind == 'O' :
7283 # Ensure we are comparing the comparable. The problem this fixes:
7384 # On entering this method "cls" will generally come from
74- # variable.possible_values, while the array values may come from directly
75- # importing a module containing an Enum class. However, variables (and
76- # hence their possible_values) are loaded by a call to load_module, which
77- # gives them a different identity from the ones imported in the usual way.
78- # So, instead of relying on the "cls" passed in, we use only its name to
79- # check that the values in the array, if non-empty, are of the right type.
85+ # variable.possible_values, while the array values may come from
86+ # directly importing a module containing an Enum class. However,
87+ # variables (and hence their possible_values) are loaded by a call
88+ # to load_module, which gives them a different identity from the
89+ # ones imported in the usual way.
90+ #
91+ # So, instead of relying on the "cls" passed in, we use only its
92+ # name to check that the values in the array, if non-empty, are of
93+ # the right type.
8094 if len (array ) > 0 and cls .__name__ is array [0 ].__class__ .__name__ :
8195 cls = array [0 ].__class__
8296
0 commit comments