Improve ConstantStepSize incrementation#666
Conversation
patrick-kidger
left a comment
There was a problem hiding this comment.
Awesome! Nits only, I think this basically LGTM.
| step = jnp.asarray(1, dtype=jnp.int32) | ||
| num_steps = jnp.astype(jnp.ceil((t1 - t0) / dt0), jnp.int32) |
There was a problem hiding this comment.
Actually, I think this looks fine to me! Maybe to handle cases in which (t1 - t0) / dt0 is ever-so-slightly above an integer (due to floating point issues) then we should substract some amount of epsilon from the numerator?
There was a problem hiding this comment.
How about the num_steps = jnp.astype(jnp.ceil((t1 - t0) / eqxi.nextafter(dt0)), jnp.int32)? We ideally want to gaurantee that dt0 is within floating point error to that specified, and this seems the most explicit way to do this (within 1 ULP).
There was a problem hiding this comment.
Yup, I think that looks like it'll probably work to me. For posterity, I tested this like so:
import numpy as np
def check(n: int, with_nextafter: bool):
t_diff = np.array(3.0)
for desired_num_steps in range(1, n):
dt0 = t_diff / desired_num_steps
if with_nextafter:
dt0 = np.nextafter(dt0, float('inf'))
calculated_num_steps = np.ceil(t_diff / dt0)
if desired_num_steps != calculated_num_steps:
print(desired_num_steps, with_nextafter)
break
check(100, with_nextafter=False)
check(1000000, with_nextafter=True)
# 47 False(Hopefully the same is still true on a GPU.)
There was a problem hiding this comment.
Not sure it works for negative dt0 (i.e. when t1 < t0). Should I put a where or use jnp.abs(jnp.ceil((t1 - t0) / eqxi.nextafter(jnp.abs(dt0)))).
There was a problem hiding this comment.
I think negative dt0 should have already been normalised by the main integrate loop before this point! (Worth including a test for just to be sure, if we don't already have one.)
There was a problem hiding this comment.
Yes, I see this now. There is a test in test_integrate called test_reverse_time which hopefully fits the bill as is. 🙂
|
Pyright has just gone mad for me and I have spent an hour trying to work out what is going wrong, it seems to think |
It's still set as the generic parameter in the |
Thanks, I wasn't sufficiently well-versed in generics to catch this, good learning experience! Should be good to go now, I didn't run the full test suite on my machine but I verified it passes pre-commit (including pyright) and |
|
We have |
|
The event tests are failing because my implementation relies on the value of |
I don't think exact behaviour here should be too important to us. It's a general 'things are happening', not something precise. I'd be happy to adjust this test with either finer step sizes or a different output we assert against.
Ah, this is a good catch. I think I agree with your approach - just special-casing this ( |
|
Should be fixed now, unless we need int64 for support for extremely large number of steps. :-) |
|
Could you please double check the failing run? It looks to me that this might have been a SIGSEGV fail due to IO/a GHA runner issue. I did not see any failing tests in the log. Thanks. |
466f7c3 to
e91c1e6
Compare
|
The failing run will be because of jax-ml/jax#30517 . So: this LGTM, and merged! 🎉 |
|
Happy with the changes, looks neat, thanks! Agree that using |
ConstantStepSize no longer increments the time step by
dton each step, but now multipliest1 - t0bystep / num_steps(which is almost identical but can be more accurate for small relative timesteps or a large number of timesteps). Whenstep == num_stepsthe timestep is set exactly tot1.Fully expect there to be numerous nit's here, especially for the int32 to result_type conversion (I could just use
with jax.numpy_dtype_promotion("standard"):or alternatively cast each int individually). I also don't thinkeqx.error_iftest is necessary as the clipping should prevent this, but it could potentially be an extra guardrail if clipping changes (again) or is removed in the future.I can confirm that merging this with the jpb/ulp branch from #660 resolves the failing test.
It would be nice to support
dt0=Nonebutmax_stepsis not accessible byinitat the moment.