Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/preprocessing/packing_2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ tspan = (0, 10.0)
ode = semidiscretize(semi, tspan)

# Use this callback to stop the simulation when it is sufficiently close to a steady state
steady_state = SteadyStateReachedCallback(; interval=10, interval_size=200,
steady_state = SteadyStateReachedCallback(; interval=10, interval_size=20,
abstol=1.0e-7, reltol=1.0e-6)

info_callback = InfoCallback(interval=50)
Expand Down
104 changes: 81 additions & 23 deletions src/callbacks/steady_state_reached.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ where `ekin` is the total kinetic energy of the simulation.

# Keywords
- `interval=0`: Check steady state condition every `interval` time steps.
A value of `0` disables step-interval checks.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
A value of `0` disables step-interval checks.
Use either `interval` or `dt`.

- `dt=0.0`: Check steady state condition in regular intervals of `dt` in terms
of integration time by adding additional `tstops`
(note that this may change the solution).
- `interval_size`: The interval in which the change of the kinetic energy is considered.
`interval_size` is a (integer) multiple of `interval` or `dt`.
- Either `interval` or `dt` must be set to something larger than 0.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- Either `interval` or `dt` must be set to something larger than 0.
Use either `interval` or `dt`.

- `interval_size`: The number of callback evaluations over which the change of the
kinetic energy is considered.
- `abstol`: Absolute tolerance.
- `reltol`: Relative tolerance.
"""
Expand All @@ -26,51 +28,107 @@ end

function SteadyStateReachedCallback(; interval::Integer=0, dt=0.0,
interval_size::Integer=10, abstol=1.0e-8, reltol=1.0e-6)
ELTYPE = eltype(abstol)
abstol, reltol = promote(abstol, reltol)
if interval < 0
throw(ArgumentError("`interval` must be non-negative"))
end

if dt > 0 && interval > 0
throw(ArgumentError("setting both `interval` and `dt` is not supported"))
end

if dt > 0
interval = convert(ELTYPE, dt)
if dt <= 0 && interval == 0
throw(ArgumentError("either `interval` or `dt` must be set to a positive value"))
end

if interval_size <= 0
throw(ArgumentError("`interval_size` must be positive"))
end

abstol, reltol = float.(promote(abstol, reltol))
ELTYPE = typeof(abstol)

interval_ = if dt > 0
convert(ELTYPE, dt)
else
Int(interval)
end

steady_state_callback = SteadyStateReachedCallback(interval, abstol, reltol,
steady_state_callback = SteadyStateReachedCallback(interval_, abstol, reltol,
[convert(ELTYPE, Inf)],
interval_size)

if dt > 0
return PeriodicCallback(steady_state_callback, dt, save_positions=(false, false),
final_affect=true)
return PeriodicCallback(steady_state_callback, dt,
initialize=(initialize_steady_state_callback!),
save_positions=(false, false))
else
return DiscreteCallback(steady_state_callback, steady_state_callback,
save_positions=(false, false))
save_positions=(false, false),
initialize=(initialize_steady_state_callback!))
end
end

function initialize_steady_state_callback!(cb, u, t, integrator)
# The `SteadyStateReachedCallback` is either `cb.affect!` (with `DiscreteCallback`)
# or `cb.affect!.affect!` (with `PeriodicCallback`).
# Let recursive dispatch handle this.
initialize_steady_state_callback!(cb.affect!, u, t, integrator)
end

function initialize_steady_state_callback!(cb::SteadyStateReachedCallback, u, t, integrator)
semi = integrator.p.semi
set_callbacks_used!(semi, integrator)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?


empty!(cb.previous_ekin)
push!(cb.previous_ekin, convert(eltype(cb.previous_ekin), Inf))

return nothing
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return nothing
return cb

end

# `affect!` (`PeriodicCallback`)
function (cb::SteadyStateReachedCallback)(integrator)
steady_state_condition!(cb, integrator) || return nothing
if !steady_state_condition!(cb, integrator)
u_modified!(integrator, false)
Comment on lines +90 to +91
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if !steady_state_condition!(cb, integrator)
u_modified!(integrator, false)
u_modified!(integrator, false)
if !steady_state_condition!(cb, integrator)

Do we even need u_modified! in the condition? Have you checked this?

return nothing
end

print_summary(integrator)

terminate!(integrator)

u_modified!(integrator, false)

Comment on lines +99 to +100
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
u_modified!(integrator, false)

return nothing
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return nothing
return cb

end

# `affect!` (`DiscreteCallback`)
function (cb::SteadyStateReachedCallback{Int})(integrator)
print_summary(integrator)

terminate!(integrator)

u_modified!(integrator, false)

return nothing
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return nothing
return cb

end

# `condition` (`DiscreteCallback`)
function (steady_state_callback::SteadyStateReachedCallback{Int})(vu_ode, t, integrator)
condition_steady_state_interval(steady_state_callback, integrator) || return false
Comment thread
svchb marked this conversation as resolved.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
condition_steady_state_interval(steady_state_callback, integrator) || return false
if condition_integrator_interval(integrator, cb.interval; save_final_solution=false)
return false
end

No need for an extra function.


return steady_state_condition!(steady_state_callback, integrator)
end

function (steady_state_callback::SteadyStateReachedCallback)(vu_ode, t, integrator)
return steady_state_condition!(steady_state_callback, integrator)
end

@inline function condition_steady_state_interval(cb::SteadyStateReachedCallback{Int},
integrator)
return condition_integrator_interval(integrator, cb.interval;
save_final_solution=false)
end

@inline function steady_state_condition!(cb, integrator)
(; abstol, reltol, previous_ekin, interval_size) = cb

Expand Down Expand Up @@ -113,8 +171,8 @@ function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:SteadyStateReachedCallb

cb_ = cb.affect!

print(io, "SteadyStateReachedCallback(abstol=", cb_.abstol, ", ", "reltol=", cb_.reltol,
")")
print(io, "SteadyStateReachedCallback(interval=", cb_.interval,
", abstol=", cb_.abstol, ", reltol=", cb_.reltol, ")")
end

function Base.show(io::IO,
Expand All @@ -124,8 +182,8 @@ function Base.show(io::IO,

cb_ = cb.affect!.affect!

print(io, "SteadyStateReachedCallback(abstol=", cb_.abstol, ", reltol=", cb_.reltol,
")")
print(io, "SteadyStateReachedCallback(dt=", cb_.interval,
", abstol=", cb_.abstol, ", reltol=", cb_.reltol, ")")
end

function Base.show(io::IO, ::MIME"text/plain",
Expand All @@ -137,10 +195,10 @@ function Base.show(io::IO, ::MIME"text/plain",
else
cb_ = cb.affect!

setup = ["absolute tolerance" => cb_.abstol,
"relative tolerance" => cb_.reltol,
"interval" => cb_.interval,
"interval size" => cb_.interval_size]
setup = ["interval" => cb_.interval,
"interval size" => cb_.interval_size,
"absolute tolerance" => cb_.abstol,
"relative tolerance" => cb_.reltol]
summary_box(io, "SteadyStateReachedCallback", setup)
end
end
Expand All @@ -155,10 +213,10 @@ function Base.show(io::IO, ::MIME"text/plain",
else
cb_ = cb.affect!.affect!

setup = ["absolute tolerance" => cb_.abstol,
"relative tolerance" => cb_.reltol,
"interval" => cb_.interval,
"interval_size" => cb_.interval_size]
setup = ["dt" => cb_.interval,
"interval size" => cb_.interval_size,
"absolute tolerance" => cb_.abstol,
"relative tolerance" => cb_.reltol]
summary_box(io, "SteadyStateReachedCallback", setup)
end
end
80 changes: 71 additions & 9 deletions test/callbacks/steady_state_reached.jl
Original file line number Diff line number Diff line change
@@ -1,51 +1,113 @@
@testset verbose=true "SteadyStateReachedCallback" begin
@testset verbose=true "show" begin
# Default
callback0 = SteadyStateReachedCallback()
callback0 = SteadyStateReachedCallback(interval=1)

show_compact = "SteadyStateReachedCallback(abstol=1.0e-8, reltol=1.0e-6)"
show_compact = "SteadyStateReachedCallback(interval=1, abstol=1.0e-8, reltol=1.0e-6)"
@test repr(callback0) == show_compact

show_box = """
┌──────────────────────────────────────────────────────────────────────────────────────────────────┐
│ SteadyStateReachedCallback │
│ ══════════════════════════ │
│ interval: ……………………………………………………… 1.0 │
│ interval size: ………………………………………… 10.0 │
│ absolute tolerance: …………………………… 1.0e-8 │
│ relative tolerance: …………………………… 1.0e-6 │
│ interval: ……………………………………………………… 0.0 │
│ interval size: ………………………………………… 10.0 │
└──────────────────────────────────────────────────────────────────────────────────────────────────┘"""
@test repr("text/plain", callback0) == show_box

callback1 = SteadyStateReachedCallback(interval=11)

show_compact = "SteadyStateReachedCallback(interval=11, abstol=1.0e-8, reltol=1.0e-6)"
@test repr(callback1) == show_compact

show_box = """
┌──────────────────────────────────────────────────────────────────────────────────────────────────┐
│ SteadyStateReachedCallback │
│ ══════════════════════════ │
│ absolute tolerance: …………………………… 1.0e-8 │
│ relative tolerance: …………………………… 1.0e-6 │
│ interval: ……………………………………………………… 11.0 │
│ interval size: ………………………………………… 10.0 │
│ absolute tolerance: …………………………… 1.0e-8 │
│ relative tolerance: …………………………… 1.0e-6 │
└──────────────────────────────────────────────────────────────────────────────────────────────────┘"""
@test repr("text/plain", callback1) == show_box

callback2 = SteadyStateReachedCallback(dt=1.2)

show_compact = "SteadyStateReachedCallback(dt=1.2, abstol=1.0e-8, reltol=1.0e-6)"
@test repr(callback2) == show_compact

show_box = """
┌──────────────────────────────────────────────────────────────────────────────────────────────────┐
│ SteadyStateReachedCallback │
│ ══════════════════════════ │
│ dt: ……………………………………………………………………… 1.2 │
│ interval size: ………………………………………… 10.0 │
│ absolute tolerance: …………………………… 1.0e-8 │
│ relative tolerance: …………………………… 1.0e-6 │
│ interval: ……………………………………………………… 1.2 │
│ interval_size: ………………………………………… 10.0 │
└──────────────────────────────────────────────────────────────────────────────────────────────────┘"""
@test repr("text/plain", callback2) == show_box
end

@testset "Illegal Input" begin
error_str = "either `interval` or `dt` must be set to a positive value"
@test_throws ArgumentError(error_str) SteadyStateReachedCallback()

error_str = "setting both `interval` and `dt` is not supported"
@test_throws ArgumentError(error_str) SteadyStateReachedCallback(dt=0.1, interval=1)

error_str = "`interval_size` must be positive"
@test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval=1,
interval_size=0)

error_str = "`interval` must be non-negative"
@test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval=-1)
@test_throws ArgumentError(error_str) SteadyStateReachedCallback(interval=-2)
end

@testset "constructor" begin
callback = SteadyStateReachedCallback(interval=1, abstol=1, reltol=1)
steady_state_cb = callback.affect!

@test steady_state_cb.abstol === 1.0
@test steady_state_cb.reltol === 1.0
@test steady_state_cb.previous_ekin == [Inf]

callback = SteadyStateReachedCallback(interval=Int32(2))
steady_state_cb = callback.affect!
@test steady_state_cb.interval === 2

push!(steady_state_cb.previous_ekin, 1.0)
semi = (; integrate_tlsph=Ref(true), update_callback_used=Ref(false))
integrator = (; p=(; semi), opts=(; callback=(; discrete_callbacks=[callback])))
callback.initialize(callback, nothing, 0.0, integrator)

@test steady_state_cb.previous_ekin == [Inf]

callback = SteadyStateReachedCallback(dt=0.1)
steady_state_cb = callback.affect!.affect!
push!(steady_state_cb.previous_ekin, 1.0)
integrator = (; p=(; semi), opts=(; callback=(; discrete_callbacks=[callback])))
TrixiParticles.initialize_steady_state_callback!(callback.affect!, nothing, 0.0,
integrator)

@test steady_state_cb.previous_ekin == [Inf]

callback = SteadyStateReachedCallback(dt=0.1, interval=0)
@test callback.affect!.affect!.interval === 0.1
end

@testset "condition interval" begin
function mock_integrator(naccept)
return (; stats=(; naccept), t=0.0, sol=(; prob=(; tspan=(0.0, 1.0))),
opts=(; tstops=[1.0], maxiters=100), iter=naccept)
end

callback = SteadyStateReachedCallback(interval=1).affect!
@test TrixiParticles.condition_steady_state_interval(callback, mock_integrator(1))

callback = SteadyStateReachedCallback(interval=10).affect!
@test !TrixiParticles.condition_steady_state_interval(callback, mock_integrator(9))
@test TrixiParticles.condition_steady_state_interval(callback, mock_integrator(10))
end
end
4 changes: 2 additions & 2 deletions test/examples/examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,8 @@
joinpath(examples_dir(), "preprocessing",
"packing_2d.jl"),
particle_spacing=0.4)
expected_coordinates = [-0.540548 -0.189943 0.191664 0.542741 -0.629391 -0.196159 0.197725 0.63081 -0.629447 -0.196158 0.19779 0.631121 -0.540483 -0.190015 0.191345 0.540433;
-0.541127 -0.630201 -0.630119 -0.539294 -0.190697 -0.196942 -0.196916 -0.190324 0.190875 0.197074 0.196955 0.190973 0.541206 0.630323 0.630178 0.541314]
expected_coordinates = [-0.540548 -0.189649 0.19137 0.542738 -0.63052 -0.196161 0.197719 0.631925 -0.630577 -0.19616 0.197787 0.632248 -0.540485 -0.189726 0.191038 0.540435;
-0.541127 -0.63133 -0.631244 -0.539296 -0.190402 -0.196943 -0.196917 -0.190013 0.190581 0.197074 0.19695 0.190685 0.541204 0.631451 0.6313 0.541311]

@test isapprox(packed_ic.coordinates, expected_coordinates, atol=1e-5)
end
Expand Down
Loading