Skip to content

Commit 434c26b

Browse files
authored
Reject NaN values in numeric inputs (#493)
1 parent 06d6d01 commit 434c26b

4 files changed

Lines changed: 96 additions & 5 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Raised a clear error when numeric simulation inputs contain NaN values.

policyengine_core/holders/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def set_input_dispatch_by_period(holder: Holder, period: Period, array: ArrayLik
3030
3131
To read more about ``set_input`` attributes, check the `documentation <https://openfisca.org/doc/coding-the-legislation/35_periods.html#set-input-automatically-process-variable-inputs-defined-for-periods-not-matching-the-definition-period>`_.
3232
"""
33-
array = holder._to_array(array)
33+
array = holder._to_array(array, validate_nan=True)
3434

3535
period_size = period.size
3636
period_unit = period.unit
@@ -70,6 +70,7 @@ def set_input_divide_by_period(holder: Holder, period: Period, array: ArrayLike)
7070
"""
7171
if not isinstance(array, numpy.ndarray):
7272
array = numpy.array(array)
73+
array = holder._to_array(array, validate_nan=True)
7374
period_size = period.size
7475
period_unit = period.unit
7576

policyengine_core/holders/holder.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def set_input(
250250
return warnings.warn(warning_message, Warning)
251251
if self.variable.value_type in (float, int) and isinstance(array, str):
252252
array = tools.eval_expression(array)
253+
self._raise_if_input_contains_nan(numpy.asarray(array))
253254
simulation = getattr(self, "simulation", None)
254255
if simulation is not None:
255256
if not hasattr(simulation, "_user_input_keys"):
@@ -263,12 +264,29 @@ def set_input(
263264
and period.unit != self.variable.definition_period
264265
):
265266
return self.variable.set_input(self, period, array)
266-
return self._set(period, array, branch_name)
267+
return self._set(period, array, branch_name, validate_nan=True)
267268
finally:
268269
if simulation is not None:
269270
simulation._user_input_contexts.pop()
270271

271-
def _to_array(self, value: Any) -> ArrayLike:
272+
def _raise_if_input_contains_nan(self, value: ArrayLike) -> None:
273+
if self.variable.value_type not in (float, int):
274+
return
275+
value = numpy.asarray(value)
276+
try:
277+
if value.dtype.kind in ("O", "S", "U"):
278+
value = value.astype(float)
279+
contains_nan = numpy.isnan(value).any()
280+
except (TypeError, ValueError):
281+
return
282+
if contains_nan:
283+
raise ValueError(
284+
'Unable to set value for variable "{}", as the input contains NaN values.'.format(
285+
self.variable.name,
286+
)
287+
)
288+
289+
def _to_array(self, value: Any, validate_nan: bool = False) -> ArrayLike:
272290
if not isinstance(value, numpy.ndarray):
273291
value = numpy.asarray(value)
274292
if value.ndim == 0:
@@ -284,6 +302,8 @@ def _to_array(self, value: Any) -> ArrayLike:
284302
self.population.entity.plural,
285303
)
286304
)
305+
if validate_nan:
306+
self._raise_if_input_contains_nan(value)
287307
if self.variable.value_type == Enum:
288308
original_value = value
289309
value = self.variable.possible_values.encode(value)
@@ -301,16 +321,22 @@ def _to_array(self, value: Any) -> ArrayLike:
301321
value.dtype,
302322
)
303323
)
324+
if validate_nan:
325+
self._raise_if_input_contains_nan(value)
304326
return value
305327

306328
def _set(
307-
self, period: Period, value: ArrayLike, branch_name: str = "default"
329+
self,
330+
period: Period,
331+
value: ArrayLike,
332+
branch_name: str = "default",
333+
validate_nan: bool = False,
308334
) -> None:
309335
simulation = getattr(self, "simulation", None)
310336
user_input_contexts = getattr(simulation, "_user_input_contexts", None)
311337
if user_input_contexts and branch_name == "default":
312338
branch_name = user_input_contexts[-1]
313-
value = self._to_array(value)
339+
value = self._to_array(value, validate_nan=validate_nan)
314340
if self.variable.definition_period != periods.ETERNITY:
315341
if period is None:
316342
raise ValueError(

tests/core/test_holders.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,66 @@ def test_set_input_float_to_int(single):
216216
simulation.person.get_holder("age").set_input(period, age)
217217
result = simulation.calculate("age", period)
218218
assert result == numpy.asarray([50])
219+
220+
221+
def test__given_nan_float_array__then_set_input_raises_value_error(single):
222+
simulation = single
223+
224+
with pytest.raises(ValueError, match='variable "salary".*NaN'):
225+
simulation.set_input("salary", period, numpy.asarray([numpy.nan]))
226+
227+
228+
def test__given_nan_int_array__then_set_input_raises_value_error(single):
229+
simulation = single
230+
231+
with pytest.raises(ValueError, match='variable "age".*NaN'):
232+
simulation.set_input("age", period, numpy.asarray([numpy.nan]))
233+
234+
235+
def test__given_object_array_containing_nan__then_set_input_raises_value_error(
236+
single,
237+
):
238+
simulation = single
239+
age = numpy.asarray([numpy.nan], dtype=object)
240+
241+
with pytest.raises(ValueError, match='variable "age".*NaN'):
242+
simulation.set_input("age", period, age)
243+
244+
245+
def test__given_nan_yearly_input__then_set_input_divide_by_period_raises_value_error(
246+
single,
247+
):
248+
simulation = single
249+
salary_holder = simulation.person.get_holder("salary")
250+
251+
with pytest.raises(ValueError, match='variable "salary".*NaN'):
252+
holders.set_input_divide_by_period(
253+
salary_holder,
254+
periods.period(2017),
255+
numpy.asarray([numpy.nan]),
256+
)
257+
258+
259+
def test__given_nan_period_dispatch_input__then_helper_raises_value_error(
260+
single,
261+
):
262+
simulation = single
263+
age_holder = simulation.person.get_holder("age")
264+
265+
with pytest.raises(ValueError, match='variable "age".*NaN'):
266+
holders.set_input_dispatch_by_period(
267+
age_holder,
268+
periods.period(2017),
269+
numpy.asarray([numpy.nan]),
270+
)
271+
272+
273+
def test__given_nan_cache_value__then_put_in_cache_keeps_internal_write_allowed(
274+
single,
275+
):
276+
simulation = single
277+
salary_holder = simulation.person.get_holder("salary")
278+
279+
salary_holder.put_in_cache(numpy.asarray([numpy.nan]), period)
280+
281+
assert numpy.isnan(salary_holder.get_array(period)).all()

0 commit comments

Comments
 (0)