Skip to content

Commit f3d77c1

Browse files
author
Mauko Quiroga
committed
Fix indexed enum typings
1 parent 1eab6d7 commit f3d77c1

2 files changed

Lines changed: 48 additions & 30 deletions

File tree

openfisca_core/indexed_enums/enum.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,40 +7,47 @@
77

88
from openfisca_core.indexed_enums import config, EnumArray
99

10+
if typing.TYPE_CHECKING:
11+
IndexedEnumArray = numpy.object_
12+
1013

1114
class Enum(enum.Enum):
1215
"""
13-
Enum based on `enum34 <https://pypi.python.org/pypi/enum34/>`_, whose items have an
14-
index.
16+
Enum based on `enum34 <https://pypi.python.org/pypi/enum34/>`_, whose items
17+
have an index.
1518
"""
1619

1720
# Tweak enums to add an index attribute to each enum item
1821
def __init__(self, name: str) -> None:
19-
# When the enum item is initialized, self._member_names_ contains the names of
20-
# the previously initialized items, so its length is the index of this item.
22+
# When the enum item is initialized, self._member_names_ contains the
23+
# names of the previously initialized items, so its length is the index
24+
# of this item.
2125
self.index = len(self._member_names_)
2226

2327
# Bypass the slow Enum.__eq__
2428
__eq__ = object.__eq__
2529

26-
# In Python 3, __hash__ must be defined if __eq__ is defined to stay hashable.
30+
# In Python 3, __hash__ must be defined if __eq__ is defined to stay
31+
# hashable.
2732
__hash__ = object.__hash__
2833

2934
@classmethod
3035
def encode(
3136
cls,
3237
array: typing.Union[
3338
EnumArray,
34-
numpy.ndarray[int],
35-
numpy.ndarray[str],
36-
numpy.ndarray[Enum],
39+
numpy.int_,
40+
numpy.float_,
41+
IndexedEnumArray,
3742
],
3843
) -> EnumArray:
3944
"""
40-
Encode a string numpy array, an enum item numpy array, or an int numpy array
41-
into an :any:`EnumArray`. See :any:`EnumArray.decode` for decoding.
45+
Encode a string numpy array, an enum item numpy array, or an int numpy
46+
array into an :any:`EnumArray`. See :any:`EnumArray.decode` for
47+
decoding.
4248
43-
:param ndarray array: Array of string identifiers, or of enum items, to encode.
49+
:param ndarray array: Array of string identifiers, or of enum items, to
50+
encode.
4451
4552
:returns: An :any:`EnumArray` encoding the input array values.
4653
:rtype: :any:`EnumArray`
@@ -59,24 +66,31 @@ def encode(
5966
>>> encoded_array[0]
6067
2 # Encoded value
6168
"""
62-
if type(array) is EnumArray:
69+
if isinstance(array, EnumArray):
6370
return array
6471

65-
if array.dtype.kind in {'U', 'S'}: # String array
72+
# String array
73+
if isinstance(array, numpy.ndarray) and \
74+
array.dtype.kind in {'U', 'S'}:
6675
array = numpy.select(
6776
[array == item.name for item in cls],
6877
[item.index for item in cls],
6978
).astype(config.ENUM_ARRAY_DTYPE)
7079

71-
elif array.dtype.kind == 'O': # Enum items arrays
80+
# Enum items arrays
81+
elif isinstance(array, numpy.ndarray) and \
82+
array.dtype.kind == 'O':
7283
# Ensure we are comparing the comparable. The problem this fixes:
7384
# On entering this method "cls" will generally come from
74-
# variable.possible_values, while the array values may come from directly
75-
# importing a module containing an Enum class. However, variables (and
76-
# hence their possible_values) are loaded by a call to load_module, which
77-
# gives them a different identity from the ones imported in the usual way.
78-
# So, instead of relying on the "cls" passed in, we use only its name to
79-
# check that the values in the array, if non-empty, are of the right type.
85+
# variable.possible_values, while the array values may come from
86+
# directly importing a module containing an Enum class. However,
87+
# variables (and hence their possible_values) are loaded by a call
88+
# to load_module, which gives them a different identity from the
89+
# ones imported in the usual way.
90+
#
91+
# So, instead of relying on the "cls" passed in, we use only its
92+
# name to check that the values in the array, if non-empty, are of
93+
# the right type.
8094
if len(array) > 0 and cls.__name__ is array[0].__class__.__name__:
8195
cls = array[0].__class__
8296

openfisca_core/indexed_enums/enum_array.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
if typing.TYPE_CHECKING:
88
from openfisca_core.indexed_enums import Enum
99

10+
IndexedEnumArray = numpy.object_
11+
1012

1113
class EnumArray(numpy.ndarray):
1214
"""
@@ -19,24 +21,24 @@ class EnumArray(numpy.ndarray):
1921
# To read more about the two following methods, see:
2022
# https://docs.scipy.org/doc/numpy-1.13.0/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array.
2123
def __new__(
22-
cls,
23-
input_array: numpy.ndarray[int],
24+
cls: typing.Type[EnumArray],
25+
input_array: numpy.int_,
2426
possible_values: typing.Optional[typing.Type[Enum]] = None,
2527
) -> EnumArray:
2628
obj = numpy.asarray(input_array).view(cls)
2729
obj.possible_values = possible_values
2830
return obj
2931

3032
# See previous comment
31-
def __array_finalize__(self, obj: typing.Optional[numpy.ndarray[int]]) -> None:
33+
def __array_finalize__(self, obj: typing.Optional[numpy.int_]) -> None:
3234
if obj is None:
3335
return
3436

3537
self.possible_values = getattr(obj, "possible_values", None)
3638

3739
def __eq__(self, other: typing.Any) -> bool:
38-
# When comparing to an item of self.possible_values, use the item index to
39-
# speed up the comparison.
40+
# When comparing to an item of self.possible_values, use the item index
41+
# to speed up the comparison.
4042
if other.__class__.__name__ is self.possible_values.__name__:
4143
# Use view(ndarray) so that the result is a classic ndarray, not an
4244
# EnumArray.
@@ -49,8 +51,8 @@ def __ne__(self, other: typing.Any) -> bool:
4951

5052
def _forbidden_operation(self, other: typing.Any) -> typing.NoReturn:
5153
raise TypeError(
52-
"Forbidden operation. The only operations allowed on EnumArrays are "
53-
"'==' and '!='.",
54+
"Forbidden operation. The only operations allowed on EnumArrays "
55+
"are '==' and '!='.",
5456
)
5557

5658
__add__ = _forbidden_operation
@@ -62,7 +64,7 @@ def _forbidden_operation(self, other: typing.Any) -> typing.NoReturn:
6264
__and__ = _forbidden_operation
6365
__or__ = _forbidden_operation
6466

65-
def decode(self) -> numpy.ndarray[Enum]:
67+
def decode(self) -> IndexedEnumArray:
6668
"""
6769
Return the array of enum items corresponding to self.
6870
@@ -72,14 +74,16 @@ def decode(self) -> numpy.ndarray[Enum]:
7274
>>> enum_array[0]
7375
>>> 2 # Encoded value
7476
>>> enum_array.decode()[0]
75-
<HousingOccupancyStatus.free_lodger: 'Free lodger'> # Decoded value : enum item
77+
<HousingOccupancyStatus.free_lodger: 'Free lodger'>
78+
79+
Decoded value: enum item
7680
"""
7781
return numpy.select(
7882
[self == item.index for item in self.possible_values],
7983
list(self.possible_values),
8084
)
8185

82-
def decode_to_str(self) -> numpy.ndarray[str]:
86+
def decode_to_str(self) -> numpy.str_:
8387
"""
8488
Return the array of string identifiers corresponding to self.
8589

0 commit comments

Comments
 (0)