@@ -116,12 +116,22 @@ def get_coordinate(self, value: ScalarFloat | Array) -> ScalarFloat | Array:
116116class LogSpacedGrid (UniformContinuousGrid ):
117117 """A logarithmically spaced grid of continuous values.
118118
119+ Requires `start > 0`.
120+
119121 Example:
120122 --------
121123 Let `start = 1`, `stop = 100`, and `n_points = 3`. The grid is `[1, 10, 100]`.
122124
123125 """
124126
127+ def __post_init__ (self ) -> None :
128+ _validate_continuous_grid (
129+ start = self .start ,
130+ stop = self .stop ,
131+ n_points = self .n_points ,
132+ requires_positive_start = True ,
133+ )
134+
125135 def to_jax (self ) -> Float1D :
126136 """Convert the grid to a Jax array."""
127137 return grid_coordinates .logspace (
@@ -188,9 +198,23 @@ def pass_points_at_runtime(self) -> bool:
188198 return self .points is None
189199
190200 def to_jax (self ) -> Float1D :
191- """Convert the grid to a Jax array."""
201+ """Convert the grid to a Jax array.
202+
203+ Raises `GridInitializationError` for runtime-supplied grids
204+ (`pass_points_at_runtime=True`). To get the substituted points,
205+ call `internal_regime.state_action_space(regime_params=...)` and
206+ read from `.states[name]` or `.continuous_actions[name]`.
207+ """
192208 if self .points is None :
193- return jnp .full (self .n_points , jnp .nan )
209+ raise GridInitializationError (
210+ f"IrregSpacedGrid declared with n_points={ self .n_points } and "
211+ f"no points; values are supplied at runtime via "
212+ f"params['<regime>']['<grid_name>']['points']. To get the "
213+ f"substituted points, call "
214+ f"`internal_regime.state_action_space(regime_params=...)` and "
215+ f"read from `.states[name]` or `.continuous_actions[name]`. "
216+ f"Use `.n_points` if only the shape is needed."
217+ )
194218 return jnp .asarray (self .points )
195219
196220 @overload
@@ -213,13 +237,16 @@ def _validate_continuous_grid(
213237 start : float ,
214238 stop : float ,
215239 n_points : int ,
240+ requires_positive_start : bool = False ,
216241) -> None :
217242 """Validate the continuous grid parameters.
218243
219244 Args:
220245 start: The start value of the grid.
221246 stop: The stop value of the grid.
222247 n_points: The number of points in the grid.
248+ requires_positive_start: If True, also require `start > 0` (used by
249+ log-spaced grids since `log(x)` is undefined for `x <= 0`).
223250
224251 Raises:
225252 GridInitializationError: If the grid parameters are invalid.
@@ -235,6 +262,15 @@ def _validate_continuous_grid(
235262 if not valid_stop_type :
236263 error_messages .append ("stop must be a scalar int or float value" )
237264
265+ # Reject NaN/inf early — `start >= stop` returns False for NaN, so an
266+ # un-finite start would otherwise pass silently and produce a broken grid.
267+ if valid_start_type and not jnp .isfinite (start ):
268+ error_messages .append (f"start must be finite, got { start } " )
269+ valid_start_type = False
270+ if valid_stop_type and not jnp .isfinite (stop ):
271+ error_messages .append (f"stop must be finite, got { stop } " )
272+ valid_stop_type = False
273+
238274 if not isinstance (n_points , int ) or n_points < 1 :
239275 error_messages .append (
240276 f"n_points must be an int greater than 0 but is { n_points } " ,
@@ -243,6 +279,12 @@ def _validate_continuous_grid(
243279 if valid_start_type and valid_stop_type and start >= stop :
244280 error_messages .append ("start must be less than stop" )
245281
282+ if valid_start_type and requires_positive_start and start <= 0 :
283+ error_messages .append (
284+ f"start must be > 0 for a log-spaced grid (got { start } ); "
285+ f"`log(x)` is undefined for `x <= 0`."
286+ )
287+
246288 if error_messages :
247289 msg = format_messages (error_messages )
248290 raise GridInitializationError (msg )
@@ -275,15 +317,24 @@ def _validate_irreg_spaced_grid(points: Sequence[float] | Float1D) -> None:
275317 f"Non-numeric elements found at indices: { non_numeric } "
276318 )
277319 else :
278- # Check that points are in ascending order
279- for i in range (len (points ) - 1 ):
280- if points [i ] >= points [i + 1 ]:
281- error_messages .append (
282- "Points must be in strictly ascending order. "
283- f"Found points[{ i } ]={ points [i ]} >= "
284- f"points[{ i + 1 } ]={ points [i + 1 ]} "
285- )
286- break
320+ # Reject NaN/inf — comparisons with NaN are False, so the
321+ # ascending-order check below would silently let them through.
322+ non_finite = [(i , p ) for i , p in enumerate (points ) if not jnp .isfinite (p )]
323+ if non_finite :
324+ error_messages .append (
325+ f"All elements of points must be finite. "
326+ f"Non-finite elements found at: { non_finite } "
327+ )
328+ else :
329+ # Check that points are in strictly ascending order
330+ for i in range (len (points ) - 1 ):
331+ if points [i ] >= points [i + 1 ]:
332+ error_messages .append (
333+ "Points must be in strictly ascending order. "
334+ f"Found points[{ i } ]={ points [i ]} >= "
335+ f"points[{ i + 1 } ]={ points [i + 1 ]} "
336+ )
337+ break
287338
288339 if error_messages :
289340 msg = format_messages (error_messages )
0 commit comments