
Improve ConstantStepSize incrementation #666

Merged
patrick-kidger merged 2 commits into patrick-kidger:dev from jpbrodrick89:jpb/constantstepsizefix
Jul 30, 2025


@jpbrodrick89
Contributor

@jpbrodrick89 jpbrodrick89 commented Jul 16, 2025

ConstantStepSize no longer increments the time step by dt on each step, but now multiplies t1 - t0 by step / num_steps (which is almost identical but can be more accurate for small relative timesteps or a large number of timesteps). When step == num_steps the timestep is set exactly to t1.

I fully expect there to be numerous nits here, especially for the int32 to result_type conversion (I could just use with jax.numpy_dtype_promotion("standard"): or alternatively cast each int individually). I also don't think the eqx.error_if test is necessary, as the clipping should prevent this, but it could be an extra guardrail if clipping changes (again) or is removed in the future.

I can confirm that merging this with the jpb/ulp branch from #660 resolves the failing test.

It would be nice to support dt0=None but max_steps is not accessible by init at the moment.
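
As a rough illustration of the change described above, here is a minimal pure-Python sketch comparing repeated addition of dt with scaling t1 - t0 by step / num_steps and snapping to t1 on the final step. This is not the diffrax implementation; the function names are made up for this example, and plain Python floats stand in for JAX arrays.

```python
def final_time_by_accumulation(t0: float, t1: float, num_steps: int) -> float:
    """Old-style scheme: add the same dt on every step."""
    dt = (t1 - t0) / num_steps
    t = t0
    for _ in range(num_steps):
        t = t + dt  # rounding error can build up across steps
    return t


def time_at_step(t0: float, t1: float, step: int, num_steps: int) -> float:
    """New-style scheme: compute each time directly from the step counter."""
    if step == num_steps:
        return t1  # the last step lands exactly on t1
    return t0 + (t1 - t0) * (step / num_steps)


if __name__ == "__main__":
    # Compare where each scheme ends up after ten steps over [0, 1].
    print(final_time_by_accumulation(0.0, 1.0, 10))
    print(time_at_step(0.0, 1.0, 10, 10))
```

With ten steps over [0, 1], the accumulated end time need not equal t1 exactly, whereas the step-counter version returns t1 by construction.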

Owner

@patrick-kidger patrick-kidger left a comment

Awesome! Nits only, I think this basically LGTM.

Comment on lines +42 to +43
step = jnp.asarray(1, dtype=jnp.int32)
num_steps = jnp.astype(jnp.ceil((t1 - t0) / dt0), jnp.int32)
Owner

Actually, I think this looks fine to me! Maybe, to handle cases in which (t1 - t0) / dt0 is ever-so-slightly above an integer (due to floating-point issues), we should subtract some small epsilon from the numerator?

Contributor Author

How about num_steps = jnp.astype(jnp.ceil((t1 - t0) / eqxi.nextafter(dt0)), jnp.int32)? Ideally we want to guarantee that the effective dt0 is within floating-point error of the one specified, and this seems the most explicit way to do that (within 1 ULP).

Owner

Yup, I think that looks like it'll probably work to me. For posterity, I tested this like so:

import numpy as np

def check(n: int, with_nextafter: bool):
    t_diff = np.array(3.0)
    for desired_num_steps in range(1, n):
        dt0 = t_diff / desired_num_steps
        if with_nextafter:
            dt0 = np.nextafter(dt0, float('inf'))
        calculated_num_steps = np.ceil(t_diff / dt0)
        if desired_num_steps != calculated_num_steps:
            print(desired_num_steps, with_nextafter)
            break

check(100, with_nextafter=False)
check(1000000, with_nextafter=True)
# 47 False

(Hopefully the same is still true on a GPU.)

Contributor Author

Not sure it works for negative dt0 (i.e. when t1 < t0). Should I put in a where, or use jnp.abs(jnp.ceil((t1 - t0) / eqxi.nextafter(jnp.abs(dt0))))?

Owner

I think negative dt0 should have already been normalised by the main integrate loop before this point! (Worth including a test for just to be sure, if we don't already have one.)

Contributor Author

Yes, I see this now. There is a test in test_integrate called test_reverse_time which hopefully fits the bill as is. 🙂

Comment thread diffrax/_step_size_controller/constant.py Outdated
@jpbrodrick89
Contributor Author

jpbrodrick89 commented Jul 18, 2025

Pyright has just gone mad for me and I have spent an hour trying to work out what is going wrong. It seems to think _ControllerState is RealScalarLike. I would appreciate a fresh pair of eyes to help resolve this, thanks. 🙏🏻

/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_step_size_controller/constant.py
  /Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_step_size_controller/constant.py:22:9 - error: Method "init" overrides class "AbstractStepSizeController" in an incompatible manner
    Return type mismatch: base method returns type "tuple[RealScalarLike, RealScalarLike]", override returns type "tuple[RealScalarLike, tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike]]"
      "tuple[RealScalarLike, tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike]]" is not assignable to "tuple[RealScalarLike, RealScalarLike]"
        Tuple entry 2 is incorrect type
          Type "tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike]" is not assignable to type "RealScalarLike"
            "tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike]" is not assignable to "bool"
            "tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike]" is not assignable to "int"
            "tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike]" is not assignable to "float"
            "tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike]" is not assignable to "Array"
    ... (reportIncompatibleMethodOverride)
  /Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_step_size_controller/constant.py:46:9 - error: Method "adapt_step_size" overrides class "AbstractStepSizeController" in an incompatible manner
    Parameter 9 type mismatch: base parameter is type "RealScalarLike", override parameter is type "tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike]"
    Return type mismatch: base method returns type "tuple[BoolScalarLike, RealScalarLike, RealScalarLike, BoolScalarLike, RealScalarLike, RESULTS]", override returns type "tuple[bool, RealScalarLike, RealScalarLike, bool, tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike], RESULTS]"
      Type "RealScalarLike" is not assignable to type "tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike]"
        "Array" is not assignable to "tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike]"
      "tuple[bool, RealScalarLike, RealScalarLike, bool, tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike], RESULTS]" is not assignable to "tuple[BoolScalarLike, RealScalarLike, RealScalarLike, BoolScalarLike, RealScalarLike, RESULTS]"
        Tuple entry 5 is incorrect type
          Type "tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike]" is not assignable to type "RealScalarLike"
            "tuple[IntScalarLike, IntScalarLike, RealScalarLike, RealScalarLike]" is not assignable to "bool"
    ... (reportIncompatibleMethodOverride)

Comment thread diffrax/_step_size_controller/constant.py Outdated
@patrick-kidger
Owner

> it seems to think _ControllerState is RealScalarLike

It's still set as the generic parameter in the class ConstantStepSize(...) definition :)
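
For readers less familiar with generics, the following minimal sketch shows the pattern behind this kind of error. The class names and state layout are hypothetical stand-ins, not diffrax's actual definitions: a base class is generic over its state type, and the subclass must pass its real state type as the generic parameter for the overridden signatures to line up.

```python
import math
from typing import Generic, TypeVar

_ControllerState = TypeVar("_ControllerState")


class AbstractController(Generic[_ControllerState]):
    """Hypothetical stand-in for a base class that is generic over its state."""

    def init(self, t0: float, t1: float, dt0: float) -> tuple[float, _ControllerState]:
        raise NotImplementedError


# State carried between steps in this sketch: (step, num_steps, t0, t1).
_ConstantState = tuple[int, int, float, float]


# Writing AbstractController[float] here while returning a tuple state is the
# kind of mismatch a type checker reports as an incompatible method override.
class ConstantController(AbstractController[_ConstantState]):
    def init(self, t0: float, t1: float, dt0: float) -> tuple[float, _ConstantState]:
        num_steps = math.ceil((t1 - t0) / dt0)
        return t0 + dt0, (1, num_steps, t0, t1)
```

The mismatch has no effect at runtime; only the static type checker sees the stale generic parameter, which is why the tests can pass while pyright fails.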

@jpbrodrick89
Contributor Author

> It's still set as the generic parameter in the class ConstantStepSize(...) definition :)

Thanks, I wasn't sufficiently well-versed in generics to catch this, good learning experience! Should be good to go now. I didn't run the full test suite on my machine, but I verified that it passes pre-commit (including pyright) and test_integrate.

@jpbrodrick89
Contributor Author

jpbrodrick89 commented Jul 24, 2025

test_text_progress_meter failure

We have minimum_increase=0.1 (i.e. 10%). However, in these tests the total number of steps is an exact multiple of ten, so whether step n * num_steps % 10 == 0.0 holds is sensitive to floating-point behaviour. To make this more predictable, my suggestion is to change minimum_increase to a less sensitive value. For example, setting minimum_increase=0.1000001 fixes the first test but not the second (as the extra 0.2% now appears at every timestep), so expected would need to be updated in that case. WDYT?

@jpbrodrick89
Contributor Author

The event tests are failing because my implementation relies on the value of t1, which is therefore not allowed to be jnp.inf. Forever-running simulations with termination events seem like something you would definitely want to continue supporting. I currently don't store the original dt0 in the state and would have to reintroduce it in some way. My current, somewhat hacky, idea to address this is to set num_steps=-1, t1_sim=dt0 in such a case and then use a conditional (where again?) to return t1_next = t0 + t1_sim, potentially renaming t1_sim to t1_sim_or_dt0. But maybe you can see a cleaner solution?

@patrick-kidger
Owner

> test_text_progress_meter

I don't think exact behaviour here should be too important to us. It's a general 'things are happening', not something precise. I'd be happy to adjust this test with either finer step sizes or a different output we assert against.

> The event tests are failing because my implementation relies on the value of t1 which is therefore not allowed to be jnp.inf

Ah, this is a good catch. I think I agree with your approach - just special-casing this (jnp.isfinite) is the smartest thing I can see to do here too.
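
The special case discussed here could look roughly like the pure-Python sketch below. It is an assumption-laden illustration, not the merged code: the helper names are invented, num_steps=-1 is used as the sentinel suggested in the comment above, and math.isfinite / math.nextafter stand in for their JAX and equinox counterparts.

```python
import math


def count_steps(t0: float, t1: float, dt0: float) -> int:
    """Number of steps to reach t1, or -1 when t1 is infinite."""
    if not math.isfinite(t1):
        return -1  # sentinel: no fixed end time, so no step count
    # Nudge dt0 up by one ULP so a ratio that is a hair above an integer
    # does not get rounded up to one step too many.
    return math.ceil((t1 - t0) / math.nextafter(dt0, math.inf))


def next_time(t0: float, t1: float, dt0: float, step: int, num_steps: int, t_prev: float) -> float:
    """End time of the given step (1-based)."""
    if num_steps < 0:
        return t_prev + dt0  # infinite horizon: fall back to plain increments
    if step >= num_steps:
        return t1  # final step lands exactly on t1
    return t0 + (t1 - t0) * (step / num_steps)
```

With a finite t1 the times are computed from the step counter, while an infinite t1 falls back to incrementing by dt0, so simulations that only stop on an event keep working.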

@jpbrodrick89
Contributor Author

jpbrodrick89 commented Jul 28, 2025

Should be fixed now, unless we need int64 to support an extremely large number of steps. :-)

@jpbrodrick89
Contributor Author

Could you please double check the failing run? It looks to me that this might have been a SIGSEGV fail due to IO/a GHA runner issue. I did not see any failing tests in the log. Thanks.

@patrick-kidger patrick-kidger changed the base branch from main to dev July 30, 2025 18:16
@patrick-kidger patrick-kidger force-pushed the jpb/constantstepsizefix branch from 466f7c3 to e91c1e6 Compare July 30, 2025 18:21
@patrick-kidger patrick-kidger merged commit 514fc9e into patrick-kidger:dev Jul 30, 2025
1 of 2 checks passed
@patrick-kidger
Owner

The failing run will be because of jax-ml/jax#30517 . So: this LGTM, and merged! 🎉

@jpbrodrick89
Contributor Author

Happy with the changes, looks neat, thanks! Agree that using t1 directly is probably not necessary. I think it will only make a difference in extreme cases where one would expect floating-point errors anyway (e.g. t0 = -1.0, t1 = 1e-16).
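
The extreme case mentioned here can be checked with a few lines of plain Python using float64 (an illustration only; the variable names are not taken from the diffrax code):

```python
t0, t1 = -1.0, 1e-16

# 1e-16 is below half an ULP of 1.0, so the span rounds to exactly 1.0.
span = t1 - t0

# Rebuilding the end time from t0 and the span therefore gives 0.0, not 1e-16.
reconstructed = t0 + span

print(span == 1.0, reconstructed == t1)
```

Here the end time rebuilt from t0 and the span differs from t1, which is the situation where returning t1 directly on the last step would matter.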
