Skip to content

Commit f101e75

Browse files
Merge pull request #252 from patrick-kidger/nan-memory-size
Fixed issue #250.
2 parents edd1250 + 0dfe3e5 commit f101e75

1 file changed

Lines changed: 16 additions & 24 deletions

File tree

diffrax/global_interpolation.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -318,40 +318,32 @@ def evaluate(
318318
if t1 is not None:
319319
return self.evaluate(t1, left=left) - self.evaluate(t0, left=left)
320320
t = t0 * self.direction
321-
ts_0 = self.ts[0]
322-
ts_1 = self.ts[self.ts_size - 1]
323-
pred = (self.ts_size > 1) & (t >= ts_0) & (t <= ts_1)
324-
eval_fn = ft.partial(self.__class__._evaluate, t=t, left=left)
325-
nan_fn = self.__class__._nan
326-
# Use cond to avoid generating nans unless we have to.
327-
out = lax.cond(pred, eval_fn, nan_fn, self)
321+
t_bounded = self._nan_if_out_of_bounds(t)
322+
out = self._get_local_interpolation(t_bounded, left).evaluate(
323+
t_bounded, left=left
324+
)
328325
keep = ft.partial(jnp.where, (t == self.t0_if_trivial) & (self.ts_size == 1))
329326
return jtu.tree_map(keep, self.y0_if_trivial, out)
330327

331328
@eqx.filter_jit
332329
def derivative(self, t: Scalar, left: bool = True) -> PyTree:
333330
t = t * self.direction
331+
t = self._nan_if_out_of_bounds(t)
332+
out = self._get_local_interpolation(t, left).derivative(t, left=left)
333+
return (self.direction * out**ω).ω
334+
335+
def _nan_if_out_of_bounds(self, t):
334336
# Note that len(self.ts) == max_steps + 1 > 0 so the indexing is always valid,
335337
# even if we throw it away because self.ts_size == 0.
336338
ts_0 = self.ts[0]
337339
ts_1 = self.ts[self.ts_size - 1]
338-
pred = (self.ts_size > 1) & (t >= ts_0) & (t <= ts_1)
339-
deriv_fn = ft.partial(self.__class__._derivative, t=t, left=left)
340-
nan_fn = self.__class__._nan
341-
# Use cond to avoid generating nans unless we have to.
342-
return lax.cond(pred, deriv_fn, nan_fn, self)
343-
344-
def _evaluate(self, t, left):
345-
return self._get_local_interpolation(t, left).evaluate(t, left=left)
346-
347-
def _derivative(self, t, left):
348-
out = self._get_local_interpolation(t, left).derivative(t, left=left)
349-
return (self.direction * out**ω).ω
350-
351-
def _nan(self):
352-
return jtu.tree_map(
353-
ft.partial(jnp.full_like, fill_value=jnp.nan), self.y0_if_trivial
354-
)
340+
out_of_bounds = (self.ts_size <= 1) | (t < ts_0) | (t > ts_1)
341+
make_nans = lambda t: jnp.where(out_of_bounds, jnp.nan, t)
342+
identity = lambda t: t
343+
# Avoid making NaNs unless we have to, by using a cond.
344+
# (For the sake of JAX_DEBUG_NANS.)
345+
t = lax.cond(eqxi.unvmap_any(out_of_bounds), make_nans, identity, t)
346+
return t
355347

356348
@property
357349
def t0(self):

0 commit comments

Comments
 (0)