@@ -318,40 +318,32 @@ def evaluate(
318318 if t1 is not None :
319319 return self .evaluate (t1 , left = left ) - self .evaluate (t0 , left = left )
320320 t = t0 * self .direction
321- ts_0 = self .ts [0 ]
322- ts_1 = self .ts [self .ts_size - 1 ]
323- pred = (self .ts_size > 1 ) & (t >= ts_0 ) & (t <= ts_1 )
324- eval_fn = ft .partial (self .__class__ ._evaluate , t = t , left = left )
325- nan_fn = self .__class__ ._nan
326- # Use cond to avoid generating nans unless we have to.
327- out = lax .cond (pred , eval_fn , nan_fn , self )
321+ t_bounded = self ._nan_if_out_of_bounds (t )
322+ out = self ._get_local_interpolation (t_bounded , left ).evaluate (
323+ t_bounded , left = left
324+ )
328325 keep = ft .partial (jnp .where , (t == self .t0_if_trivial ) & (self .ts_size == 1 ))
329326 return jtu .tree_map (keep , self .y0_if_trivial , out )
330327
331328 @eqx .filter_jit
332329 def derivative (self , t : Scalar , left : bool = True ) -> PyTree :
333330 t = t * self .direction
331+ t = self ._nan_if_out_of_bounds (t )
332+ out = self ._get_local_interpolation (t , left ).derivative (t , left = left )
333+ return (self .direction * out ** ω ).ω
334+
335+ def _nan_if_out_of_bounds (self , t ):
334336 # Note that len(self.ts) == max_steps + 1 > 0 so the indexing is always valid,
335337 # even if we throw it away because self.ts_size == 0.
336338 ts_0 = self .ts [0 ]
337339 ts_1 = self .ts [self .ts_size - 1 ]
338- pred = (self .ts_size > 1 ) & (t >= ts_0 ) & (t <= ts_1 )
339- deriv_fn = ft .partial (self .__class__ ._derivative , t = t , left = left )
340- nan_fn = self .__class__ ._nan
341- # Use cond to avoid generating nans unless we have to.
342- return lax .cond (pred , deriv_fn , nan_fn , self )
343-
344- def _evaluate (self , t , left ):
345- return self ._get_local_interpolation (t , left ).evaluate (t , left = left )
346-
347- def _derivative (self , t , left ):
348- out = self ._get_local_interpolation (t , left ).derivative (t , left = left )
349- return (self .direction * out ** ω ).ω
350-
351- def _nan (self ):
352- return jtu .tree_map (
353- ft .partial (jnp .full_like , fill_value = jnp .nan ), self .y0_if_trivial
354- )
340+ out_of_bounds = (self .ts_size <= 1 ) | (t < ts_0 ) | (t > ts_1 )
341+ make_nans = lambda t : jnp .where (out_of_bounds , jnp .nan , t )
342+ identity = lambda t : t
343+ # Avoid making NaNs unless we have to, by using a cond.
344+ # (For the sake of JAX_DEBUG_NANS.)
345+ t = lax .cond (eqxi .unvmap_any (out_of_bounds ), make_nans , identity , t )
346+ return t
355347
356348 @property
357349 def t0 (self ):
0 commit comments