1111
1212from .custom_types import Array , DenseInfos , Int , PyTree , Scalar
1313from .local_interpolation import AbstractLocalInterpolation
14- from .misc import fill_forward , left_broadcast_to
14+ from .misc import fill_forward , left_broadcast_to , linear_rescale
1515from .path import AbstractPath
1616
1717
@@ -124,10 +124,10 @@ def _index(_ys):
124124 next_ys = (self .ys ** ω )[index + 1 ].ω
125125 prev_t = self .ts [index ]
126126 next_t = self .ts [index + 1 ]
127- diff_t = next_t - prev_t
128-
129127 return (
130- prev_ys ** ω + (next_ys ** ω - prev_ys ** ω ) * (fractional_part / diff_t )
128+ prev_ys ** ω
129+ + (next_ys ** ω - prev_ys ** ω )
130+ * (linear_rescale (prev_t , fractional_part , next_t ))
131131 ).ω
132132
133133 @eqx .filter_jit
@@ -407,7 +407,6 @@ def _linear_interpolation_forward(
407407 Tuple [Array ["channels" :...], Array ["channels" :...]], # noqa: F821
408408 Array ["channels" :...], # noqa: F821
409409]:
410-
411410 prev_ti , prev_yi = carry
412411 ti , yi , next_ti , next_yi = value
413412 cond = jnp .isnan (yi )
@@ -426,7 +425,6 @@ def _linear_interpolation(
426425 ys : Array ["times" , "channels" :...], # noqa: F821
427426 replace_nans_at_start : Optional [Array ["channels" :...]] = None , # noqa: F821
428427) -> Array ["times" , "channels" :...]: # noqa: F821
429-
430428 ts = left_broadcast_to (ts , ys .shape )
431429
432430 if replace_nans_at_start is None :
@@ -599,7 +597,6 @@ def _hermite_forward(
599597 Array ["channels" :...], # noqa: F821
600598 ],
601599]:
602-
603600 prev_ti , prev_yi , prev_deriv_i = carry
604601 ti , yi , next_ti , next_yi = value
605602 first_deriv_i = (next_yi - yi ) / (next_ti - ti )
0 commit comments