diff --git a/examples/preprocessing/packing_2d.jl b/examples/preprocessing/packing_2d.jl index 02951323ef..fcf0d40c64 100644 --- a/examples/preprocessing/packing_2d.jl +++ b/examples/preprocessing/packing_2d.jl @@ -83,7 +83,7 @@ tspan = (0, 10.0) ode = semidiscretize(semi, tspan) # Use this callback to stop the simulation when it is sufficiently close to a steady state -steady_state = SteadyStateReachedCallback(; interval=10, interval_size=200, +steady_state = SteadyStateReachedCallback(; interval=10, interval_size=20, abstol=1.0e-7, reltol=1.0e-6) info_callback = InfoCallback(interval=50) diff --git a/src/callbacks/steady_state_reached.jl b/src/callbacks/steady_state_reached.jl index bca0faafea..ea246e750b 100644 --- a/src/callbacks/steady_state_reached.jl +++ b/src/callbacks/steady_state_reached.jl @@ -8,11 +8,13 @@ where `ekin` is the total kinetic energy of the simulation. # Keywords - `interval=0`: Check steady state condition every `interval` time steps. + A value of `0` disables step-interval checks. - `dt=0.0`: Check steady state condition in regular intervals of `dt` in terms of integration time by adding additional `tstops` (note that this may change the solution). -- `interval_size`: The interval in which the change of the kinetic energy is considered. - `interval_size` is a (integer) multiple of `interval` or `dt`. +- Either `interval` or `dt` must be set to something larger than 0. +- `interval_size`: The number of callback evaluations over which the change of the + kinetic energy is considered. - `abstol`: Absolute tolerance. - `reltol`: Relative tolerance. """ @@ -26,37 +28,77 @@ end function SteadyStateReachedCallback(; interval::Integer=0, dt=0.0, interval_size::Integer=10, abstol=1.0e-8, reltol=1.0e-6) - ELTYPE = eltype(abstol) - abstol, reltol = promote(abstol, reltol) + if interval < 0 + throw(ArgumentError("`interval` must be non-negative")) + end if dt > 0 && interval > 0 throw(ArgumentError("setting both `interval` and `dt` is not supported")) end - if dt > 0 - interval = convert(ELTYPE, dt) + if dt <= 0 && interval == 0 + throw(ArgumentError("either `interval` or `dt` must be set to a positive value")) + end + + if interval_size <= 0 + throw(ArgumentError("`interval_size` must be positive")) + end + + abstol, reltol = float.(promote(abstol, reltol)) + ELTYPE = typeof(abstol) + + interval_ = if dt > 0 + convert(ELTYPE, dt) + else + Int(interval) end - steady_state_callback = SteadyStateReachedCallback(interval, abstol, reltol, + steady_state_callback = SteadyStateReachedCallback(interval_, abstol, reltol, [convert(ELTYPE, Inf)], interval_size) if dt > 0 - return PeriodicCallback(steady_state_callback, dt, save_positions=(false, false), - final_affect=true) + return PeriodicCallback(steady_state_callback, dt, + initialize=(initialize_steady_state_callback!), + save_positions=(false, false)) else return DiscreteCallback(steady_state_callback, steady_state_callback, - save_positions=(false, false)) + save_positions=(false, false), + initialize=(initialize_steady_state_callback!)) end end +function initialize_steady_state_callback!(cb, u, t, integrator) + # The `SteadyStateReachedCallback` is either `cb.affect!` (with `DiscreteCallback`) + # or `cb.affect!.affect!` (with `PeriodicCallback`). + # Let recursive dispatch handle this. + initialize_steady_state_callback!(cb.affect!, u, t, integrator) +end + +function initialize_steady_state_callback!(cb::SteadyStateReachedCallback, u, t, integrator) + semi = integrator.p.semi + set_callbacks_used!(semi, integrator) + + empty!(cb.previous_ekin) + push!(cb.previous_ekin, convert(eltype(cb.previous_ekin), Inf)) + + return nothing +end + # `affect!` (`PeriodicCallback`) function (cb::SteadyStateReachedCallback)(integrator) - steady_state_condition!(cb, integrator) || return nothing + if !steady_state_condition!(cb, integrator) + u_modified!(integrator, false) + return nothing + end print_summary(integrator) terminate!(integrator) + + u_modified!(integrator, false) + + return nothing end # `affect!` (`DiscreteCallback`) @@ -64,13 +106,29 @@ function (cb::SteadyStateReachedCallback{Int})(integrator) print_summary(integrator) terminate!(integrator) + + u_modified!(integrator, false) + + return nothing end # `condition` (`DiscreteCallback`) +function (steady_state_callback::SteadyStateReachedCallback{Int})(vu_ode, t, integrator) + condition_steady_state_interval(steady_state_callback, integrator) || return false + + return steady_state_condition!(steady_state_callback, integrator) +end + function (steady_state_callback::SteadyStateReachedCallback)(vu_ode, t, integrator) return steady_state_condition!(steady_state_callback, integrator) end +@inline function condition_steady_state_interval(cb::SteadyStateReachedCallback{Int}, + integrator) + return condition_integrator_interval(integrator, cb.interval; + save_final_solution=false) +end + @inline function steady_state_condition!(cb, integrator) (; abstol, reltol, previous_ekin, interval_size) = cb @@ -113,8 +171,8 @@ function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:SteadyStateReachedCallb cb_ = cb.affect! - print(io, "SteadyStateReachedCallback(abstol=", cb_.abstol, ", ", "reltol=", cb_.reltol, - ")") + print(io, "SteadyStateReachedCallback(interval=", cb_.interval, + ", abstol=", cb_.abstol, ", reltol=", cb_.reltol, ")") end function Base.show(io::IO, @@ -124,8 +182,8 @@ function Base.show(io::IO, cb_ = cb.affect!.affect! - print(io, "SteadyStateReachedCallback(abstol=", cb_.abstol, ", reltol=", cb_.reltol, - ")") + print(io, "SteadyStateReachedCallback(dt=", cb_.interval, + ", abstol=", cb_.abstol, ", reltol=", cb_.reltol, ")") end function Base.show(io::IO, ::MIME"text/plain", @@ -137,10 +195,10 @@ function Base.show(io::IO, ::MIME"text/plain", else cb_ = cb.affect! - setup = ["absolute tolerance" => cb_.abstol, - "relative tolerance" => cb_.reltol, - "interval" => cb_.interval, - "interval size" => cb_.interval_size] + setup = ["interval" => cb_.interval, + "interval size" => cb_.interval_size, + "absolute tolerance" => cb_.abstol, + "relative tolerance" => cb_.reltol] summary_box(io, "SteadyStateReachedCallback", setup) end end @@ -155,10 +213,10 @@ function Base.show(io::IO, ::MIME"text/plain", else cb_ = cb.affect!.affect! - setup = ["absolute tolerance" => cb_.abstol, - "relative tolerance" => cb_.reltol, - "interval" => cb_.interval, - "interval_size" => cb_.interval_size] + setup = ["dt" => cb_.interval, + "interval size" => cb_.interval_size, + "absolute tolerance" => cb_.abstol, + "relative tolerance" => cb_.reltol] summary_box(io, "SteadyStateReachedCallback", setup) end end diff --git a/test/callbacks/steady_state_reached.jl b/test/callbacks/steady_state_reached.jl index 0b0f99fdc7..71bda0bcb9 100644 --- a/test/callbacks/steady_state_reached.jl +++ b/test/callbacks/steady_state_reached.jl @@ -1,51 +1,113 @@ @testset verbose=true "SteadyStateReachedCallback" begin @testset verbose=true "show" begin - # Default - callback0 = SteadyStateReachedCallback() + callback0 = SteadyStateReachedCallback(interval=1) - show_compact = "SteadyStateReachedCallback(abstol=1.0e-8, reltol=1.0e-6)" + show_compact = "SteadyStateReachedCallback(interval=1, abstol=1.0e-8, reltol=1.0e-6)" @test repr(callback0) == show_compact show_box = """ ┌──────────────────────────────────────────────────────────────────────────────────────────────────┐ │ SteadyStateReachedCallback │ │ ══════════════════════════ │ + │ interval: ……………………………………………………… 1.0 │ + │ interval size: ………………………………………… 10.0 │ │ absolute tolerance: …………………………… 1.0e-8 │ │ relative tolerance: …………………………… 1.0e-6 │ - │ interval: ……………………………………………………… 0.0 │ - │ interval size: ………………………………………… 10.0 │ └──────────────────────────────────────────────────────────────────────────────────────────────────┘""" @test repr("text/plain", callback0) == show_box callback1 = SteadyStateReachedCallback(interval=11) + show_compact = "SteadyStateReachedCallback(interval=11, abstol=1.0e-8, reltol=1.0e-6)" + @test repr(callback1) == show_compact + show_box = """ ┌──────────────────────────────────────────────────────────────────────────────────────────────────┐ │ SteadyStateReachedCallback │ │ ══════════════════════════ │ - │ absolute tolerance: …………………………… 1.0e-8 │ - │ relative tolerance: …………………………… 1.0e-6 │ │ interval: ……………………………………………………… 11.0 │ │ interval size: ………………………………………… 10.0 │ + │ absolute tolerance: …………………………… 1.0e-8 │ + │ relative tolerance: …………………………… 1.0e-6 │ └──────────────────────────────────────────────────────────────────────────────────────────────────┘""" @test repr("text/plain", callback1) == show_box callback2 = SteadyStateReachedCallback(dt=1.2) + show_compact = "SteadyStateReachedCallback(dt=1.2, abstol=1.0e-8, reltol=1.0e-6)" + @test repr(callback2) == show_compact + show_box = """ ┌──────────────────────────────────────────────────────────────────────────────────────────────────┐ │ SteadyStateReachedCallback │ │ ══════════════════════════ │ + │ dt: ……………………………………………………………………… 1.2 │ + │ interval size: ………………………………………… 10.0 │ │ absolute tolerance: …………………………… 1.0e-8 │ │ relative tolerance: …………………………… 1.0e-6 │ - │ interval: ……………………………………………………… 1.2 │ - │ interval_size: ………………………………………… 10.0 │ └──────────────────────────────────────────────────────────────────────────────────────────────────┘""" @test repr("text/plain", callback2) == show_box end @testset "Illegal Input" begin + error_str = "either `interval` or `dt` must be set to a positive value" + @test_throws ArgumentError(error_str) SteadyStateReachedCallback() + error_str = "setting both `interval` and `dt` is not supported" @test_throws ArgumentError(error_str) SteadyStateReachedCallback(dt=0.1, interval=1) + + error_str = "`interval_size` must be positive" + @test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval=1, + interval_size=0) + + error_str = "`interval` must be non-negative" + @test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval=-1) + @test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval=-2) + end + + @testset "constructor" begin + callback = SteadyStateReachedCallback(interval=1, abstol=1, reltol=1) + steady_state_cb = callback.affect! + + @test steady_state_cb.abstol === 1.0 + @test steady_state_cb.reltol === 1.0 + @test steady_state_cb.previous_ekin == [Inf] + + callback = SteadyStateReachedCallback(interval=Int32(2)) + steady_state_cb = callback.affect! + @test steady_state_cb.interval === 2 + + push!(steady_state_cb.previous_ekin, 1.0) + semi = (; integrate_tlsph=Ref(true), update_callback_used=Ref(false)) + integrator = (; p=(; semi), opts=(; callback=(; discrete_callbacks=[callback]))) + callback.initialize(callback, nothing, 0.0, integrator) + + @test steady_state_cb.previous_ekin == [Inf] + + callback = SteadyStateReachedCallback(dt=0.1) + steady_state_cb = callback.affect!.affect! + push!(steady_state_cb.previous_ekin, 1.0) + integrator = (; p=(; semi), opts=(; callback=(; discrete_callbacks=[callback]))) + TrixiParticles.initialize_steady_state_callback!(callback.affect!, nothing, 0.0, + integrator) + + @test steady_state_cb.previous_ekin == [Inf] + + callback = SteadyStateReachedCallback(dt=0.1, interval=0) + @test callback.affect!.affect!.interval === 0.1 + end + + @testset "condition interval" begin + function mock_integrator(naccept) + return (; stats=(; naccept), t=0.0, sol=(; prob=(; tspan=(0.0, 1.0))), + opts=(; tstops=[1.0], maxiters=100), iter=naccept) + end + + callback = SteadyStateReachedCallback(interval=1).affect! + @test TrixiParticles.condition_steady_state_interval(callback, mock_integrator(1)) + + callback = SteadyStateReachedCallback(interval=10).affect! + @test !TrixiParticles.condition_steady_state_interval(callback, mock_integrator(9)) + @test TrixiParticles.condition_steady_state_interval(callback, mock_integrator(10)) end end diff --git a/test/examples/examples.jl b/test/examples/examples.jl index 03d1e5e433..2bb8c871e6 100644 --- a/test/examples/examples.jl +++ b/test/examples/examples.jl @@ -572,8 +572,8 @@ joinpath(examples_dir(), "preprocessing", "packing_2d.jl"), particle_spacing=0.4) - expected_coordinates = [-0.540548 -0.189943 0.191664 0.542741 -0.629391 -0.196159 0.197725 0.63081 -0.629447 -0.196158 0.19779 0.631121 -0.540483 -0.190015 0.191345 0.540433; - -0.541127 -0.630201 -0.630119 -0.539294 -0.190697 -0.196942 -0.196916 -0.190324 0.190875 0.197074 0.196955 0.190973 0.541206 0.630323 0.630178 0.541314] + expected_coordinates = [-0.540548 -0.189649 0.19137 0.542738 -0.63052 -0.196161 0.197719 0.631925 -0.630577 -0.19616 0.197787 0.632248 -0.540485 -0.189726 0.191038 0.540435; + -0.541127 -0.63133 -0.631244 -0.539296 -0.190402 -0.196943 -0.196917 -0.190013 0.190581 0.197074 0.19695 0.190685 0.541204 0.631451 0.6313 0.541311] @test isapprox(packed_ic.coordinates, expected_coordinates, atol=1e-5) end