@@ -250,6 +250,7 @@ def set_input(
250250 return warnings .warn (warning_message , Warning )
251251 if self .variable .value_type in (float , int ) and isinstance (array , str ):
252252 array = tools .eval_expression (array )
253+ self ._raise_if_input_contains_nan (numpy .asarray (array ))
253254 simulation = getattr (self , "simulation" , None )
254255 if simulation is not None :
255256 if not hasattr (simulation , "_user_input_keys" ):
@@ -263,12 +264,29 @@ def set_input(
263264 and period .unit != self .variable .definition_period
264265 ):
265266 return self .variable .set_input (self , period , array )
266- return self ._set (period , array , branch_name )
267+ return self ._set (period , array , branch_name , validate_nan = True )
267268 finally :
268269 if simulation is not None :
269270 simulation ._user_input_contexts .pop ()
270271
271- def _to_array (self , value : Any ) -> ArrayLike :
272+ def _raise_if_input_contains_nan (self , value : ArrayLike ) -> None :
273+ if self .variable .value_type not in (float , int ):
274+ return
275+ value = numpy .asarray (value )
276+ try :
277+ if value .dtype .kind in ("O" , "S" , "U" ):
278+ value = value .astype (float )
279+ contains_nan = numpy .isnan (value ).any ()
280+ except (TypeError , ValueError ):
281+ return
282+ if contains_nan :
283+ raise ValueError (
284+ 'Unable to set value for variable "{}", as the input contains NaN values.' .format (
285+ self .variable .name ,
286+ )
287+ )
288+
289+ def _to_array (self , value : Any , validate_nan : bool = False ) -> ArrayLike :
272290 if not isinstance (value , numpy .ndarray ):
273291 value = numpy .asarray (value )
274292 if value .ndim == 0 :
@@ -284,6 +302,8 @@ def _to_array(self, value: Any) -> ArrayLike:
284302 self .population .entity .plural ,
285303 )
286304 )
305+ if validate_nan :
306+ self ._raise_if_input_contains_nan (value )
287307 if self .variable .value_type == Enum :
288308 original_value = value
289309 value = self .variable .possible_values .encode (value )
@@ -301,16 +321,22 @@ def _to_array(self, value: Any) -> ArrayLike:
301321 value .dtype ,
302322 )
303323 )
324+ if validate_nan :
325+ self ._raise_if_input_contains_nan (value )
304326 return value
305327
306328 def _set (
307- self , period : Period , value : ArrayLike , branch_name : str = "default"
329+ self ,
330+ period : Period ,
331+ value : ArrayLike ,
332+ branch_name : str = "default" ,
333+ validate_nan : bool = False ,
308334 ) -> None :
309335 simulation = getattr (self , "simulation" , None )
310336 user_input_contexts = getattr (simulation , "_user_input_contexts" , None )
311337 if user_input_contexts and branch_name == "default" :
312338 branch_name = user_input_contexts [- 1 ]
313- value = self ._to_array (value )
339+ value = self ._to_array (value , validate_nan = validate_nan )
314340 if self .variable .definition_period != periods .ETERNITY :
315341 if period is None :
316342 raise ValueError (
0 commit comments