-
Notifications
You must be signed in to change notification settings - Fork 20
[codex] Fix steady state callback interval handling #1180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ec1bfc9
056e498
1743325
bed80e5
f708c5c
8db26fb
4a2beee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,11 +8,13 @@ where `ekin` is the total kinetic energy of the simulation. | |||||||||||
|
|
||||||||||||
| # Keywords | ||||||||||||
| - `interval=0`: Check steady state condition every `interval` time steps. | ||||||||||||
| A value of `0` disables step-interval checks. | ||||||||||||
| - `dt=0.0`: Check steady state condition in regular intervals of `dt` in terms | ||||||||||||
| of integration time by adding additional `tstops` | ||||||||||||
| (note that this may change the solution). | ||||||||||||
| - `interval_size`: The interval in which the change of the kinetic energy is considered. | ||||||||||||
| `interval_size` is a (integer) multiple of `interval` or `dt`. | ||||||||||||
| - Either `interval` or `dt` must be set to something larger than 0. | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
| - `interval_size`: The number of callback evaluations over which the change of the | ||||||||||||
| kinetic energy is considered. | ||||||||||||
| - `abstol`: Absolute tolerance. | ||||||||||||
| - `reltol`: Relative tolerance. | ||||||||||||
| """ | ||||||||||||
|
|
@@ -26,51 +28,107 @@ end | |||||||||||
|
|
||||||||||||
| function SteadyStateReachedCallback(; interval::Integer=0, dt=0.0, | ||||||||||||
| interval_size::Integer=10, abstol=1.0e-8, reltol=1.0e-6) | ||||||||||||
| ELTYPE = eltype(abstol) | ||||||||||||
| abstol, reltol = promote(abstol, reltol) | ||||||||||||
| if interval < 0 | ||||||||||||
| throw(ArgumentError("`interval` must be non-negative")) | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| if dt > 0 && interval > 0 | ||||||||||||
| throw(ArgumentError("setting both `interval` and `dt` is not supported")) | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| if dt > 0 | ||||||||||||
| interval = convert(ELTYPE, dt) | ||||||||||||
| if dt <= 0 && interval == 0 | ||||||||||||
| throw(ArgumentError("either `interval` or `dt` must be set to a positive value")) | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| if interval_size <= 0 | ||||||||||||
| throw(ArgumentError("`interval_size` must be positive")) | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| abstol, reltol = float.(promote(abstol, reltol)) | ||||||||||||
| ELTYPE = typeof(abstol) | ||||||||||||
|
|
||||||||||||
| interval_ = if dt > 0 | ||||||||||||
| convert(ELTYPE, dt) | ||||||||||||
| else | ||||||||||||
| Int(interval) | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| steady_state_callback = SteadyStateReachedCallback(interval, abstol, reltol, | ||||||||||||
| steady_state_callback = SteadyStateReachedCallback(interval_, abstol, reltol, | ||||||||||||
| [convert(ELTYPE, Inf)], | ||||||||||||
| interval_size) | ||||||||||||
|
|
||||||||||||
| if dt > 0 | ||||||||||||
| return PeriodicCallback(steady_state_callback, dt, save_positions=(false, false), | ||||||||||||
| final_affect=true) | ||||||||||||
| return PeriodicCallback(steady_state_callback, dt, | ||||||||||||
| initialize=(initialize_steady_state_callback!), | ||||||||||||
| save_positions=(false, false)) | ||||||||||||
| else | ||||||||||||
| return DiscreteCallback(steady_state_callback, steady_state_callback, | ||||||||||||
| save_positions=(false, false)) | ||||||||||||
| save_positions=(false, false), | ||||||||||||
| initialize=(initialize_steady_state_callback!)) | ||||||||||||
| end | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| function initialize_steady_state_callback!(cb, u, t, integrator) | ||||||||||||
| # The `SteadyStateReachedCallback` is either `cb.affect!` (with `DiscreteCallback`) | ||||||||||||
| # or `cb.affect!.affect!` (with `PeriodicCallback`). | ||||||||||||
| # Let recursive dispatch handle this. | ||||||||||||
| initialize_steady_state_callback!(cb.affect!, u, t, integrator) | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| function initialize_steady_state_callback!(cb::SteadyStateReachedCallback, u, t, integrator) | ||||||||||||
| semi = integrator.p.semi | ||||||||||||
| set_callbacks_used!(semi, integrator) | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why? |
||||||||||||
|
|
||||||||||||
| empty!(cb.previous_ekin) | ||||||||||||
| push!(cb.previous_ekin, convert(eltype(cb.previous_ekin), Inf)) | ||||||||||||
|
|
||||||||||||
| return nothing | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| # `affect!` (`PeriodicCallback`) | ||||||||||||
| function (cb::SteadyStateReachedCallback)(integrator) | ||||||||||||
| steady_state_condition!(cb, integrator) || return nothing | ||||||||||||
| if !steady_state_condition!(cb, integrator) | ||||||||||||
| u_modified!(integrator, false) | ||||||||||||
|
Comment on lines
+90
to
+91
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Do we even need |
||||||||||||
| return nothing | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| print_summary(integrator) | ||||||||||||
|
|
||||||||||||
| terminate!(integrator) | ||||||||||||
|
|
||||||||||||
| u_modified!(integrator, false) | ||||||||||||
|
|
||||||||||||
|
Comment on lines
+99
to
+100
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
| return nothing | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| # `affect!` (`DiscreteCallback`) | ||||||||||||
| function (cb::SteadyStateReachedCallback{Int})(integrator) | ||||||||||||
| print_summary(integrator) | ||||||||||||
|
|
||||||||||||
| terminate!(integrator) | ||||||||||||
|
|
||||||||||||
| u_modified!(integrator, false) | ||||||||||||
|
|
||||||||||||
| return nothing | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| # `condition` (`DiscreteCallback`) | ||||||||||||
| function (steady_state_callback::SteadyStateReachedCallback{Int})(vu_ode, t, integrator) | ||||||||||||
| condition_steady_state_interval(steady_state_callback, integrator) || return false | ||||||||||||
|
svchb marked this conversation as resolved.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
No need for an extra function. |
||||||||||||
|
|
||||||||||||
| return steady_state_condition!(steady_state_callback, integrator) | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| function (steady_state_callback::SteadyStateReachedCallback)(vu_ode, t, integrator) | ||||||||||||
| return steady_state_condition!(steady_state_callback, integrator) | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| @inline function condition_steady_state_interval(cb::SteadyStateReachedCallback{Int}, | ||||||||||||
| integrator) | ||||||||||||
| return condition_integrator_interval(integrator, cb.interval; | ||||||||||||
| save_final_solution=false) | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| @inline function steady_state_condition!(cb, integrator) | ||||||||||||
| (; abstol, reltol, previous_ekin, interval_size) = cb | ||||||||||||
|
|
||||||||||||
|
|
@@ -113,8 +171,8 @@ function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:SteadyStateReachedCallb | |||||||||||
|
|
||||||||||||
| cb_ = cb.affect! | ||||||||||||
|
|
||||||||||||
| print(io, "SteadyStateReachedCallback(abstol=", cb_.abstol, ", ", "reltol=", cb_.reltol, | ||||||||||||
| ")") | ||||||||||||
| print(io, "SteadyStateReachedCallback(interval=", cb_.interval, | ||||||||||||
| ", abstol=", cb_.abstol, ", reltol=", cb_.reltol, ")") | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| function Base.show(io::IO, | ||||||||||||
|
|
@@ -124,8 +182,8 @@ function Base.show(io::IO, | |||||||||||
|
|
||||||||||||
| cb_ = cb.affect!.affect! | ||||||||||||
|
|
||||||||||||
| print(io, "SteadyStateReachedCallback(abstol=", cb_.abstol, ", reltol=", cb_.reltol, | ||||||||||||
| ")") | ||||||||||||
| print(io, "SteadyStateReachedCallback(dt=", cb_.interval, | ||||||||||||
| ", abstol=", cb_.abstol, ", reltol=", cb_.reltol, ")") | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| function Base.show(io::IO, ::MIME"text/plain", | ||||||||||||
|
|
@@ -137,10 +195,10 @@ function Base.show(io::IO, ::MIME"text/plain", | |||||||||||
| else | ||||||||||||
| cb_ = cb.affect! | ||||||||||||
|
|
||||||||||||
| setup = ["absolute tolerance" => cb_.abstol, | ||||||||||||
| "relative tolerance" => cb_.reltol, | ||||||||||||
| "interval" => cb_.interval, | ||||||||||||
| "interval size" => cb_.interval_size] | ||||||||||||
| setup = ["interval" => cb_.interval, | ||||||||||||
| "interval size" => cb_.interval_size, | ||||||||||||
| "absolute tolerance" => cb_.abstol, | ||||||||||||
| "relative tolerance" => cb_.reltol] | ||||||||||||
| summary_box(io, "SteadyStateReachedCallback", setup) | ||||||||||||
| end | ||||||||||||
| end | ||||||||||||
|
|
@@ -155,10 +213,10 @@ function Base.show(io::IO, ::MIME"text/plain", | |||||||||||
| else | ||||||||||||
| cb_ = cb.affect!.affect! | ||||||||||||
|
|
||||||||||||
| setup = ["absolute tolerance" => cb_.abstol, | ||||||||||||
| "relative tolerance" => cb_.reltol, | ||||||||||||
| "interval" => cb_.interval, | ||||||||||||
| "interval_size" => cb_.interval_size] | ||||||||||||
| setup = ["dt" => cb_.interval, | ||||||||||||
| "interval size" => cb_.interval_size, | ||||||||||||
| "absolute tolerance" => cb_.abstol, | ||||||||||||
| "relative tolerance" => cb_.reltol] | ||||||||||||
| summary_box(io, "SteadyStateReachedCallback", setup) | ||||||||||||
| end | ||||||||||||
| end | ||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,51 +1,113 @@ | ||
| @testset verbose=true "SteadyStateReachedCallback" begin | ||
| @testset verbose=true "show" begin | ||
| # Default | ||
| callback0 = SteadyStateReachedCallback() | ||
| callback0 = SteadyStateReachedCallback(interval=1) | ||
|
|
||
| show_compact = "SteadyStateReachedCallback(abstol=1.0e-8, reltol=1.0e-6)" | ||
| show_compact = "SteadyStateReachedCallback(interval=1, abstol=1.0e-8, reltol=1.0e-6)" | ||
| @test repr(callback0) == show_compact | ||
|
|
||
| show_box = """ | ||
| ┌──────────────────────────────────────────────────────────────────────────────────────────────────┐ | ||
| │ SteadyStateReachedCallback │ | ||
| │ ══════════════════════════ │ | ||
| │ interval: ……………………………………………………… 1.0 │ | ||
| │ interval size: ………………………………………… 10.0 │ | ||
| │ absolute tolerance: …………………………… 1.0e-8 │ | ||
| │ relative tolerance: …………………………… 1.0e-6 │ | ||
| │ interval: ……………………………………………………… 0.0 │ | ||
| │ interval size: ………………………………………… 10.0 │ | ||
| └──────────────────────────────────────────────────────────────────────────────────────────────────┘""" | ||
| @test repr("text/plain", callback0) == show_box | ||
|
|
||
| callback1 = SteadyStateReachedCallback(interval=11) | ||
|
|
||
| show_compact = "SteadyStateReachedCallback(interval=11, abstol=1.0e-8, reltol=1.0e-6)" | ||
| @test repr(callback1) == show_compact | ||
|
|
||
| show_box = """ | ||
| ┌──────────────────────────────────────────────────────────────────────────────────────────────────┐ | ||
| │ SteadyStateReachedCallback │ | ||
| │ ══════════════════════════ │ | ||
| │ absolute tolerance: …………………………… 1.0e-8 │ | ||
| │ relative tolerance: …………………………… 1.0e-6 │ | ||
| │ interval: ……………………………………………………… 11.0 │ | ||
| │ interval size: ………………………………………… 10.0 │ | ||
| │ absolute tolerance: …………………………… 1.0e-8 │ | ||
| │ relative tolerance: …………………………… 1.0e-6 │ | ||
| └──────────────────────────────────────────────────────────────────────────────────────────────────┘""" | ||
| @test repr("text/plain", callback1) == show_box | ||
|
|
||
| callback2 = SteadyStateReachedCallback(dt=1.2) | ||
|
|
||
| show_compact = "SteadyStateReachedCallback(dt=1.2, abstol=1.0e-8, reltol=1.0e-6)" | ||
| @test repr(callback2) == show_compact | ||
|
|
||
| show_box = """ | ||
| ┌──────────────────────────────────────────────────────────────────────────────────────────────────┐ | ||
| │ SteadyStateReachedCallback │ | ||
| │ ══════════════════════════ │ | ||
| │ dt: ……………………………………………………………………… 1.2 │ | ||
| │ interval size: ………………………………………… 10.0 │ | ||
| │ absolute tolerance: …………………………… 1.0e-8 │ | ||
| │ relative tolerance: …………………………… 1.0e-6 │ | ||
| │ interval: ……………………………………………………… 1.2 │ | ||
| │ interval_size: ………………………………………… 10.0 │ | ||
| └──────────────────────────────────────────────────────────────────────────────────────────────────┘""" | ||
| @test repr("text/plain", callback2) == show_box | ||
| end | ||
|
|
||
| @testset "Illegal Input" begin | ||
| error_str = "either `interval` or `dt` must be set to a positive value" | ||
| @test_throws ArgumentError(error_str) SteadyStateReachedCallback() | ||
|
|
||
| error_str = "setting both `interval` and `dt` is not supported" | ||
| @test_throws ArgumentError(error_str) SteadyStateReachedCallback(dt=0.1, interval=1) | ||
|
|
||
| error_str = "`interval_size` must be positive" | ||
| @test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval=1, | ||
| interval_size=0) | ||
|
|
||
| error_str = "`interval` must be non-negative" | ||
| @test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval=-1) | ||
| @test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval=-2) | ||
| end | ||
|
|
||
| @testset "constructor" begin | ||
| callback = SteadyStateReachedCallback(interval=1, abstol=1, reltol=1) | ||
| steady_state_cb = callback.affect! | ||
|
|
||
| @test steady_state_cb.abstol === 1.0 | ||
| @test steady_state_cb.reltol === 1.0 | ||
| @test steady_state_cb.previous_ekin == [Inf] | ||
|
|
||
| callback = SteadyStateReachedCallback(interval=Int32(2)) | ||
| steady_state_cb = callback.affect! | ||
| @test steady_state_cb.interval === 2 | ||
|
|
||
| push!(steady_state_cb.previous_ekin, 1.0) | ||
| semi = (; integrate_tlsph=Ref(true), update_callback_used=Ref(false)) | ||
| integrator = (; p=(; semi), opts=(; callback=(; discrete_callbacks=[callback]))) | ||
| callback.initialize(callback, nothing, 0.0, integrator) | ||
|
|
||
| @test steady_state_cb.previous_ekin == [Inf] | ||
|
|
||
| callback = SteadyStateReachedCallback(dt=0.1) | ||
| steady_state_cb = callback.affect!.affect! | ||
| push!(steady_state_cb.previous_ekin, 1.0) | ||
| integrator = (; p=(; semi), opts=(; callback=(; discrete_callbacks=[callback]))) | ||
| TrixiParticles.initialize_steady_state_callback!(callback.affect!, nothing, 0.0, | ||
| integrator) | ||
|
|
||
| @test steady_state_cb.previous_ekin == [Inf] | ||
|
|
||
| callback = SteadyStateReachedCallback(dt=0.1, interval=0) | ||
| @test callback.affect!.affect!.interval === 0.1 | ||
| end | ||
|
|
||
| @testset "condition interval" begin | ||
| function mock_integrator(naccept) | ||
| return (; stats=(; naccept), t=0.0, sol=(; prob=(; tspan=(0.0, 1.0))), | ||
| opts=(; tstops=[1.0], maxiters=100), iter=naccept) | ||
| end | ||
|
|
||
| callback = SteadyStateReachedCallback(interval=1).affect! | ||
| @test TrixiParticles.condition_steady_state_interval(callback, mock_integrator(1)) | ||
|
|
||
| callback = SteadyStateReachedCallback(interval=10).affect! | ||
| @test !TrixiParticles.condition_steady_state_interval(callback, mock_integrator(9)) | ||
| @test TrixiParticles.condition_steady_state_interval(callback, mock_integrator(10)) | ||
| end | ||
| end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.