From ec1bfc9e9476f64e7c345e2c0d833452b268253b Mon Sep 17 00:00:00 2001 From: Sven Berger Date: Mon, 11 May 2026 17:36:30 +0200 Subject: [PATCH 1/5] Fix steady state callback interval handling --- src/callbacks/steady_state_reached.jl | 20 +++++++++++++++++--- test/callbacks/steady_state_reached.jl | 25 +++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/src/callbacks/steady_state_reached.jl b/src/callbacks/steady_state_reached.jl index dfbb8ec75f..386fe97271 100644 --- a/src/callbacks/steady_state_reached.jl +++ b/src/callbacks/steady_state_reached.jl @@ -26,13 +26,16 @@ end function SteadyStateReachedCallback(; interval::Integer=0, dt=0.0, interval_size::Integer=10, abstol=1.0e-8, reltol=1.0e-6) - ELTYPE = eltype(abstol) - abstol, reltol = promote(abstol, reltol) - if dt > 0 && interval > 0 throw(ArgumentError("setting both `interval` and `dt` is not supported")) end + interval_size > 0 || + throw(ArgumentError("`interval_size` must be positive")) + + abstol, reltol = float.(promote(abstol, reltol)) + ELTYPE = typeof(abstol) + if dt > 0 interval = convert(ELTYPE, dt) end @@ -67,10 +70,21 @@ function (cb::SteadyStateReachedCallback{Int})(integrator) end # `condition` (`DiscreteCallback`) +function (steady_state_callback::SteadyStateReachedCallback{Int})(vu_ode, t, integrator) + condition_steady_state_interval(steady_state_callback, integrator) || return false + + return steady_state_condition!(steady_state_callback, integrator) +end + function (steady_state_callback::SteadyStateReachedCallback)(vu_ode, t, integrator) return steady_state_condition!(steady_state_callback, integrator) end +@inline function condition_steady_state_interval(cb::SteadyStateReachedCallback{Int}, integrator) + return cb.interval == 0 || + condition_integrator_interval(integrator, cb.interval; save_final_solution=false) +end + @inline function steady_state_condition!(cb, integrator) (; abstol, reltol, previous_ekin, interval_size) = cb diff --git a/test/callbacks/steady_state_reached.jl b/test/callbacks/steady_state_reached.jl index 0b0f99fdc7..56e2f86d93 100644 --- a/test/callbacks/steady_state_reached.jl +++ b/test/callbacks/steady_state_reached.jl @@ -47,5 +47,30 @@ @testset "Illegal Input" begin error_str = "setting both `interval` and `dt` is not supported" @test_throws ArgumentError(error_str) SteadyStateReachedCallback(dt=0.1, interval=1) + + error_str = "`interval_size` must be positive" + @test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval_size=0) + end + + @testset "constructor" begin + callback = SteadyStateReachedCallback(interval=1, abstol=1, reltol=1) + steady_state_cb = callback.affect! + + @test steady_state_cb.abstol === 1.0 + @test steady_state_cb.reltol === 1.0 + @test steady_state_cb.previous_ekin == [Inf] + end + + @testset "condition interval" begin + function mock_integrator(naccept) + return (; stats=(; naccept)) + end + + callback = SteadyStateReachedCallback(interval=0).affect! + @test TrixiParticles.condition_steady_state_interval(callback, mock_integrator(1)) + + callback = SteadyStateReachedCallback(interval=10).affect! + @test !TrixiParticles.condition_steady_state_interval(callback, mock_integrator(9)) + @test TrixiParticles.condition_steady_state_interval(callback, mock_integrator(10)) end end From 056e4986b41669b6b6c02da39078ac1b315e0141 Mon Sep 17 00:00:00 2001 From: Sven Berger Date: Tue, 12 May 2026 12:44:40 +0200 Subject: [PATCH 2/5] format --- src/callbacks/steady_state_reached.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/callbacks/steady_state_reached.jl b/src/callbacks/steady_state_reached.jl index 386fe97271..0a541a0d45 100644 --- a/src/callbacks/steady_state_reached.jl +++ b/src/callbacks/steady_state_reached.jl @@ -80,7 +80,8 @@ function (steady_state_callback::SteadyStateReachedCallback)(vu_ode, t, integrat return steady_state_condition!(steady_state_callback, integrator) end -@inline function condition_steady_state_interval(cb::SteadyStateReachedCallback{Int}, integrator) +@inline function condition_steady_state_interval(cb::SteadyStateReachedCallback{Int}, + integrator) return cb.interval == 0 || condition_integrator_interval(integrator, cb.interval; save_final_solution=false) end From 17433258f781ae5d05c510f862026ac1fac2e61b Mon Sep 17 00:00:00 2001 From: Sven Berger Date: Tue, 12 May 2026 13:08:03 +0200 Subject: [PATCH 3/5] Fix steady state callback interval size and update expected coordinates in tests --- examples/preprocessing/packing_2d.jl | 2 +- src/callbacks/steady_state_reached.jl | 4 ++-- test/examples/examples.jl | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/preprocessing/packing_2d.jl b/examples/preprocessing/packing_2d.jl index 02951323ef..fcf0d40c64 100644 --- a/examples/preprocessing/packing_2d.jl +++ b/examples/preprocessing/packing_2d.jl @@ -83,7 +83,7 @@ tspan = (0, 10.0) ode = semidiscretize(semi, tspan) # Use this callback to stop the simulation when it is sufficiently close to a steady state -steady_state = SteadyStateReachedCallback(; interval=10, interval_size=200, +steady_state = SteadyStateReachedCallback(; interval=10, interval_size=20, abstol=1.0e-7, reltol=1.0e-6) info_callback = InfoCallback(interval=50) diff --git a/src/callbacks/steady_state_reached.jl b/src/callbacks/steady_state_reached.jl index 0a541a0d45..ef7529c7ca 100644 --- a/src/callbacks/steady_state_reached.jl +++ b/src/callbacks/steady_state_reached.jl @@ -11,8 +11,8 @@ where `ekin` is the total kinetic energy of the simulation. - `dt=0.0`: Check steady state condition in regular intervals of `dt` in terms of integration time by adding additional `tstops` (note that this may change the solution). -- `interval_size`: The interval in which the change of the kinetic energy is considered. - `interval_size` is a (integer) multiple of `interval` or `dt`. +- `interval_size`: The number of callback evaluations over which the change of the + kinetic energy is considered. - `abstol`: Absolute tolerance. - `reltol`: Relative tolerance. """ diff --git a/test/examples/examples.jl b/test/examples/examples.jl index 68e79c25cc..5f4559b651 100644 --- a/test/examples/examples.jl +++ b/test/examples/examples.jl @@ -415,8 +415,8 @@ joinpath(examples_dir(), "preprocessing", "packing_2d.jl"), particle_spacing=0.4) - expected_coordinates = [-0.540548 -0.189943 0.191664 0.542741 -0.629391 -0.196159 0.197725 0.63081 -0.629447 -0.196158 0.19779 0.631121 -0.540483 -0.190015 0.191345 0.540433; - -0.541127 -0.630201 -0.630119 -0.539294 -0.190697 -0.196942 -0.196916 -0.190324 0.190875 0.197074 0.196955 0.190973 0.541206 0.630323 0.630178 0.541314] + expected_coordinates = [-0.540548 -0.189649 0.19137 0.542738 -0.63052 -0.196161 0.197719 0.631925 -0.630577 -0.19616 0.197787 0.632248 -0.540485 -0.189726 0.191038 0.540435; + -0.541127 -0.63133 -0.631244 -0.539296 -0.190402 -0.196943 -0.196917 -0.190013 0.190581 0.197074 0.19695 0.190685 0.541204 0.631451 0.6313 0.541311] @test isapprox(packed_ic.coordinates, expected_coordinates, atol=1e-5) end From f708c5cf72a90992bea98722cbfcf2bbb97e7f18 Mon Sep 17 00:00:00 2001 From: Sven Berger Date: Tue, 12 May 2026 16:12:30 +0200 Subject: [PATCH 4/5] Update src/callbacks/steady_state_reached.jl Co-authored-by: Erik Faulhaber <44124897+efaulhaber@users.noreply.github.com> --- src/callbacks/steady_state_reached.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/callbacks/steady_state_reached.jl b/src/callbacks/steady_state_reached.jl index 66db8b2997..cce5871d31 100644 --- a/src/callbacks/steady_state_reached.jl +++ b/src/callbacks/steady_state_reached.jl @@ -30,8 +30,9 @@ function SteadyStateReachedCallback(; interval::Integer=0, dt=0.0, throw(ArgumentError("setting both `interval` and `dt` is not supported")) end - interval_size > 0 || + if interval_size <= 0 throw(ArgumentError("`interval_size` must be positive")) + end abstol, reltol = float.(promote(abstol, reltol)) ELTYPE = typeof(abstol) From 8db26fbf321e0de62b820fff054a0cee2505501f Mon Sep 17 00:00:00 2001 From: Sven Berger Date: Sun, 17 May 2026 20:32:13 +0200 Subject: [PATCH 5/5] fix consistency issues --- src/callbacks/steady_state_reached.jl | 84 +++++++++++++++++++------- test/callbacks/steady_state_reached.jl | 61 +++++++++++++++---- 2 files changed, 112 insertions(+), 33 deletions(-) diff --git a/src/callbacks/steady_state_reached.jl b/src/callbacks/steady_state_reached.jl index cce5871d31..ea246e750b 100644 --- a/src/callbacks/steady_state_reached.jl +++ b/src/callbacks/steady_state_reached.jl @@ -8,9 +8,11 @@ where `ekin` is the total kinetic energy of the simulation. # Keywords - `interval=0`: Check steady state condition every `interval` time steps. + A value of `0` disables step-interval checks. - `dt=0.0`: Check steady state condition in regular intervals of `dt` in terms of integration time by adding additional `tstops` (note that this may change the solution). +- Either `interval` or `dt` must be set to something larger than 0. - `interval_size`: The number of callback evaluations over which the change of the kinetic energy is considered. - `abstol`: Absolute tolerance. @@ -26,10 +28,18 @@ end function SteadyStateReachedCallback(; interval::Integer=0, dt=0.0, interval_size::Integer=10, abstol=1.0e-8, reltol=1.0e-6) + if interval < 0 + throw(ArgumentError("`interval` must be non-negative")) + end + if dt > 0 && interval > 0 throw(ArgumentError("setting both `interval` and `dt` is not supported")) end + if dt <= 0 && interval == 0 + throw(ArgumentError("either `interval` or `dt` must be set to a positive value")) + end + if interval_size <= 0 throw(ArgumentError("`interval_size` must be positive")) end @@ -37,30 +47,58 @@ function SteadyStateReachedCallback(; interval::Integer=0, dt=0.0, abstol, reltol = float.(promote(abstol, reltol)) ELTYPE = typeof(abstol) - if dt > 0 - interval = convert(ELTYPE, dt) + interval_ = if dt > 0 + convert(ELTYPE, dt) + else + Int(interval) end - steady_state_callback = SteadyStateReachedCallback(interval, abstol, reltol, + steady_state_callback = SteadyStateReachedCallback(interval_, abstol, reltol, [convert(ELTYPE, Inf)], interval_size) if dt > 0 - return PeriodicCallback(steady_state_callback, dt, save_positions=(false, false), - final_affect=true) + return PeriodicCallback(steady_state_callback, dt, + initialize=(initialize_steady_state_callback!), + save_positions=(false, false)) else return DiscreteCallback(steady_state_callback, steady_state_callback, - save_positions=(false, false)) + save_positions=(false, false), + initialize=(initialize_steady_state_callback!)) end end +function initialize_steady_state_callback!(cb, u, t, integrator) + # The `SteadyStateReachedCallback` is either `cb.affect!` (with `DiscreteCallback`) + # or `cb.affect!.affect!` (with `PeriodicCallback`). + # Let recursive dispatch handle this. + initialize_steady_state_callback!(cb.affect!, u, t, integrator) +end + +function initialize_steady_state_callback!(cb::SteadyStateReachedCallback, u, t, integrator) + semi = integrator.p.semi + set_callbacks_used!(semi, integrator) + + empty!(cb.previous_ekin) + push!(cb.previous_ekin, convert(eltype(cb.previous_ekin), Inf)) + + return nothing +end + # `affect!` (`PeriodicCallback`) function (cb::SteadyStateReachedCallback)(integrator) - steady_state_condition!(cb, integrator) || return nothing + if !steady_state_condition!(cb, integrator) + u_modified!(integrator, false) + return nothing + end print_summary(integrator) terminate!(integrator) + + u_modified!(integrator, false) + + return nothing end # `affect!` (`DiscreteCallback`) @@ -68,6 +106,10 @@ function (cb::SteadyStateReachedCallback{Int})(integrator) print_summary(integrator) terminate!(integrator) + + u_modified!(integrator, false) + + return nothing end # `condition` (`DiscreteCallback`) @@ -83,8 +125,8 @@ end @inline function condition_steady_state_interval(cb::SteadyStateReachedCallback{Int}, integrator) - return cb.interval == 0 || - condition_integrator_interval(integrator, cb.interval; save_final_solution=false) + return condition_integrator_interval(integrator, cb.interval; + save_final_solution=false) end @inline function steady_state_condition!(cb, integrator) @@ -129,8 +171,8 @@ function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:SteadyStateReachedCallb cb_ = cb.affect! - print(io, "SteadyStateReachedCallback(abstol=", cb_.abstol, ", ", "reltol=", cb_.reltol, - ")") + print(io, "SteadyStateReachedCallback(interval=", cb_.interval, + ", abstol=", cb_.abstol, ", reltol=", cb_.reltol, ")") end function Base.show(io::IO, @@ -140,8 +182,8 @@ function Base.show(io::IO, cb_ = cb.affect!.affect! - print(io, "SteadyStateReachedCallback(abstol=", cb_.abstol, ", reltol=", cb_.reltol, - ")") + print(io, "SteadyStateReachedCallback(dt=", cb_.interval, + ", abstol=", cb_.abstol, ", reltol=", cb_.reltol, ")") end function Base.show(io::IO, ::MIME"text/plain", @@ -153,10 +195,10 @@ function Base.show(io::IO, ::MIME"text/plain", else cb_ = cb.affect! - setup = ["absolute tolerance" => cb_.abstol, - "relative tolerance" => cb_.reltol, - "interval" => cb_.interval, - "interval size" => cb_.interval_size] + setup = ["interval" => cb_.interval, + "interval size" => cb_.interval_size, + "absolute tolerance" => cb_.abstol, + "relative tolerance" => cb_.reltol] summary_box(io, "SteadyStateReachedCallback", setup) end end @@ -171,10 +213,10 @@ function Base.show(io::IO, ::MIME"text/plain", else cb_ = cb.affect!.affect! - setup = ["absolute tolerance" => cb_.abstol, - "relative tolerance" => cb_.reltol, - "interval" => cb_.interval, - "interval_size" => cb_.interval_size] + setup = ["dt" => cb_.interval, + "interval size" => cb_.interval_size, + "absolute tolerance" => cb_.abstol, + "relative tolerance" => cb_.reltol] summary_box(io, "SteadyStateReachedCallback", setup) end end diff --git a/test/callbacks/steady_state_reached.jl b/test/callbacks/steady_state_reached.jl index 56e2f86d93..71bda0bcb9 100644 --- a/test/callbacks/steady_state_reached.jl +++ b/test/callbacks/steady_state_reached.jl @@ -1,55 +1,68 @@ @testset verbose=true "SteadyStateReachedCallback" begin @testset verbose=true "show" begin - # Default - callback0 = SteadyStateReachedCallback() + callback0 = SteadyStateReachedCallback(interval=1) - show_compact = "SteadyStateReachedCallback(abstol=1.0e-8, reltol=1.0e-6)" + show_compact = "SteadyStateReachedCallback(interval=1, abstol=1.0e-8, reltol=1.0e-6)" @test repr(callback0) == show_compact show_box = """ ┌──────────────────────────────────────────────────────────────────────────────────────────────────┐ │ SteadyStateReachedCallback │ │ ══════════════════════════ │ + │ interval: ……………………………………………………… 1.0 │ + │ interval size: ………………………………………… 10.0 │ │ absolute tolerance: …………………………… 1.0e-8 │ │ relative tolerance: …………………………… 1.0e-6 │ - │ interval: ……………………………………………………… 0.0 │ - │ interval size: ………………………………………… 10.0 │ └──────────────────────────────────────────────────────────────────────────────────────────────────┘""" @test repr("text/plain", callback0) == show_box callback1 = SteadyStateReachedCallback(interval=11) + show_compact = "SteadyStateReachedCallback(interval=11, abstol=1.0e-8, reltol=1.0e-6)" + @test repr(callback1) == show_compact + show_box = """ ┌──────────────────────────────────────────────────────────────────────────────────────────────────┐ │ SteadyStateReachedCallback │ │ ══════════════════════════ │ - │ absolute tolerance: …………………………… 1.0e-8 │ - │ relative tolerance: …………………………… 1.0e-6 │ │ interval: ……………………………………………………… 11.0 │ │ interval size: ………………………………………… 10.0 │ + │ absolute tolerance: …………………………… 1.0e-8 │ + │ relative tolerance: …………………………… 1.0e-6 │ └──────────────────────────────────────────────────────────────────────────────────────────────────┘""" @test repr("text/plain", callback1) == show_box callback2 = SteadyStateReachedCallback(dt=1.2) + show_compact = "SteadyStateReachedCallback(dt=1.2, abstol=1.0e-8, reltol=1.0e-6)" + @test repr(callback2) == show_compact + show_box = """ ┌──────────────────────────────────────────────────────────────────────────────────────────────────┐ │ SteadyStateReachedCallback │ │ ══════════════════════════ │ + │ dt: ……………………………………………………………………… 1.2 │ + │ interval size: ………………………………………… 10.0 │ │ absolute tolerance: …………………………… 1.0e-8 │ │ relative tolerance: …………………………… 1.0e-6 │ - │ interval: ……………………………………………………… 1.2 │ - │ interval_size: ………………………………………… 10.0 │ └──────────────────────────────────────────────────────────────────────────────────────────────────┘""" @test repr("text/plain", callback2) == show_box end @testset "Illegal Input" begin + error_str = "either `interval` or `dt` must be set to a positive value" + @test_throws ArgumentError(error_str) SteadyStateReachedCallback() + error_str = "setting both `interval` and `dt` is not supported" @test_throws ArgumentError(error_str) SteadyStateReachedCallback(dt=0.1, interval=1) error_str = "`interval_size` must be positive" - @test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval_size=0) + @test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval=1, + interval_size=0) + + error_str = "`interval` must be non-negative" + @test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval=-1) + @test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval=-2) end @testset "constructor" begin @@ -59,14 +72,38 @@ @test steady_state_cb.abstol === 1.0 @test steady_state_cb.reltol === 1.0 @test steady_state_cb.previous_ekin == [Inf] + + callback = SteadyStateReachedCallback(interval=Int32(2)) + steady_state_cb = callback.affect! + @test steady_state_cb.interval === 2 + + push!(steady_state_cb.previous_ekin, 1.0) + semi = (; integrate_tlsph=Ref(true), update_callback_used=Ref(false)) + integrator = (; p=(; semi), opts=(; callback=(; discrete_callbacks=[callback]))) + callback.initialize(callback, nothing, 0.0, integrator) + + @test steady_state_cb.previous_ekin == [Inf] + + callback = SteadyStateReachedCallback(dt=0.1) + steady_state_cb = callback.affect!.affect! + push!(steady_state_cb.previous_ekin, 1.0) + integrator = (; p=(; semi), opts=(; callback=(; discrete_callbacks=[callback]))) + TrixiParticles.initialize_steady_state_callback!(callback.affect!, nothing, 0.0, + integrator) + + @test steady_state_cb.previous_ekin == [Inf] + + callback = SteadyStateReachedCallback(dt=0.1, interval=0) + @test callback.affect!.affect!.interval === 0.1 end @testset "condition interval" begin function mock_integrator(naccept) - return (; stats=(; naccept)) + return (; stats=(; naccept), t=0.0, sol=(; prob=(; tspan=(0.0, 1.0))), + opts=(; tstops=[1.0], maxiters=100), iter=naccept) end - callback = SteadyStateReachedCallback(interval=0).affect! + callback = SteadyStateReachedCallback(interval=1).affect! @test TrixiParticles.condition_steady_state_interval(callback, mock_integrator(1)) callback = SteadyStateReachedCallback(interval=10).affect!