From 65979bd134cd88a60577d8bd0d8c80ecb022e298 Mon Sep 17 00:00:00 2001 From: Collin Wittenstein Date: Fri, 23 May 2025 15:44:24 +0200 Subject: [PATCH 01/11] added IMEX --- examples/kdv_1d/kdv_1d_IMEX.jl | 46 ++++++++++++++++ src/DispersiveShallowWater.jl | 2 +- src/equations/kdv_1d.jl | 96 +++++++++++++++++++++++++++++++++- src/semidiscretization.jl | 36 +++++++++++-- 4 files changed, 174 insertions(+), 6 deletions(-) create mode 100644 examples/kdv_1d/kdv_1d_IMEX.jl diff --git a/examples/kdv_1d/kdv_1d_IMEX.jl b/examples/kdv_1d/kdv_1d_IMEX.jl new file mode 100644 index 000000000..fb792ed31 --- /dev/null +++ b/examples/kdv_1d/kdv_1d_IMEX.jl @@ -0,0 +1,46 @@ +using OrdinaryDiffEqTsit5 +using DispersiveShallowWater +using SummationByPartsOperators: upwind_operators, periodic_derivative_operator + +############################################################################### +# Semidiscretization of the KdV equation + +equations = KdVEquation1D(gravity = 9.81, D = 1.0) +initial_condition = initial_condition_convergence_test +boundary_conditions = boundary_condition_periodic + +# create homogeneous mesh +coordinates_min = -50.0 +coordinates_max = 50.0 +N = 512 +mesh = Mesh1D(coordinates_min, coordinates_max, N) + +# Create solver with periodic SBP operators of accuracy order 3, +# which results in a 4th order accurate semi discretizations. +# We can set the accuracy order of the upwind operators to 3 since +# we only use central versions/combinations of the upwind operators. +D1_upwind = upwind_operators(periodic_derivative_operator; + derivative_order = 1, accuracy_order = 3, + xmin = xmin(mesh), xmax = xmax(mesh), + N = nnodes(mesh)) +solver = Solver(D1_upwind) + +semi = Semidiscretization(mesh, equations, initial_condition, solver, + boundary_conditions = boundary_conditions) + +tspan = (0.0, 5.0) +ode = semidiscretize(semi, tspan, no_splitform = false) + +summary_callback = SummaryCallback() +analysis_callback = AnalysisCallback(semi; interval = 100, + extra_analysis_errors = (:conservation_error,), + extra_analysis_integrals = (waterheight_total, + waterheight)) +callbacks = CallbackSet(analysis_callback, summary_callback) +saveat = range(tspan..., length = 100) + +# alg = Rodas5() # not working because of https://github.com/SciML/OrdinaryDiffEq.jl/issues/2719 +# alg = KenCarp4() # would need to add OrdinaryDiffEqSDIRK - can to that, but I dont want to bloat DSW.jl with package s +alg = Tsit5() +sol = solve(ode, alg, abstol = 1e-7, reltol = 1e-7, + save_everystep = false, callback = callbacks, saveat = saveat) diff --git a/src/DispersiveShallowWater.jl b/src/DispersiveShallowWater.jl index 562147e74..fcddc10f2 100644 --- a/src/DispersiveShallowWater.jl +++ b/src/DispersiveShallowWater.jl @@ -30,7 +30,7 @@ using RecursiveArrayTools: ArrayPartition using Reexport: @reexport using Roots: AlefeldPotraShi, find_zero -using SciMLBase: SciMLBase, DiscreteCallback, ODEProblem, ODESolution +using SciMLBase: SciMLBase, DiscreteCallback, ODEProblem, ODESolution, SplitFunction import SciMLBase: u_modified! @reexport using StaticArrays: SVector diff --git a/src/equations/kdv_1d.jl b/src/equations/kdv_1d.jl index 59c9f2a85..eb15142f5 100644 --- a/src/equations/kdv_1d.jl +++ b/src/equations/kdv_1d.jl @@ -158,7 +158,7 @@ function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, # set D1 for hyperbolic terms D1 = solver.D1 end - + # deta = 1 / 6 sqrt(g * D) D^2 eta_xxx @.. deta = -1 / 6 * c_0 * DD * tmp_1 end @@ -181,3 +181,97 @@ function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, return nothing end + +#= +function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, + ::BoundaryConditionPeriodic, source_terms, solver, cache) + # define locally bc to not use a global variable + bc = boundary_condition_periodic + (; tmp1) = cache + rhs_1!(dq, q, t, mesh, equations, initial_condition, + bc, source_terms, solver, cache) + + deta, = dq.x + tmp1 .= deta + + # rhs_split_2! needs to be defined to set "deta = ..." + # and not just "deta += ..." in order to work with SplitODEProblem + # This means one either has to store the result in a temporary variable + # or have a function rhs! does not call rhs_split_1! and rhs_split_2!. + # and currently it does not work with ForwardDiff + rhs_split_2!(dq, q, t, mesh, equations, initial_condition, + bc, source_terms, solver, cache) + + deta .+= tmp1 + return nothing +end +# =# + +function rhs_split_1!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, + ::BoundaryConditionPeriodic, source_terms, solver, cache) + eta, = q.x + deta, = dq.x + + (; c_0, DD) = cache + # In order to use automatic differentiation, we need to extract + # the storage vectors using `get_tmp` from PreallocationTools.jl + # so they can also hold dual numbers when needed. + tmp_1 = get_tmp(cache.tmp_1, eta) + tmp_2 = get_tmp(cache.tmp_2, eta) + + @trixi_timeit timer() "third-order derivatives" begin + if solver.D1 isa PeriodicUpwindOperators && isnothing(solver.D3) + # eta_xxx = Dp * Dc * Dm * eta + mul!(tmp_1, solver.D1.minus, eta) + mul!(tmp_2, solver.D1.central, tmp_1) + mul!(tmp_1, solver.D1.plus, tmp_2) + else + # eta_xxx = D3 * eta + mul!(tmp_1, solver.D3, eta) + end + + # deta = 1 / 6 sqrt(g * D) D^2 eta_xxx + + @.. deta = -1 / 6 * c_0 * DD * tmp_1 + end + + return nothing +end + +function rhs_split_2!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, + ::BoundaryConditionPeriodic, source_terms, solver, cache) + eta, = q.x + deta, = dq.x + + (; c_0, c_1) = cache + # In order to use automatic differentiation, we need to extract + # the storage vectors using `get_tmp` from PreallocationTools.jl + # so they can also hold dual numbers when needed. + tmp_1 = get_tmp(cache.tmp_1, eta) + tmp_2 = get_tmp(cache.tmp_2, eta) + + if solver.D1 isa PeriodicUpwindOperators && isnothing(solver.D3) + D1 = solver.D1.central + else + D1 = solver.D1 + end + + @trixi_timeit timer() "hyperbolic" begin + # eta2 = eta^2 + @.. tmp_1 = eta^2 + + # eta2_x = D1 * eta2 + mul!(tmp_2, D1, tmp_1) + + # eta_x = D1 * eta + mul!(tmp_1, D1, eta) + + # deta -= sqrt(g * D) * eta_x + 1 / 2 * sqrt(g / D) * (eta * eta_x + eta2_x) + @.. deta = -(c_0 * tmp_1 + c_1 * (eta * tmp_1 + tmp_2)) + end + + @trixi_timeit timer() "source terms" calc_sources!(dq, q, t, source_terms, equations, + solver) + + return nothing +end diff --git a/src/semidiscretization.jl b/src/semidiscretization.jl index 30deba8b9..7bf15b1d2 100644 --- a/src/semidiscretization.jl +++ b/src/semidiscretization.jl @@ -186,6 +186,28 @@ function rhs!(dq, q, semi::Semidiscretization, t) return nothing end +function rhs_split_1!(dq, q, semi::Semidiscretization, t) + @unpack mesh, equations, initial_condition, boundary_conditions, solver, source_terms, cache = semi + + @trixi_timeit timer() "rhs_split_1!" rhs_split_1!(dq, q, t, mesh, equations, + initial_condition, + boundary_conditions, source_terms, + solver, + cache) + return nothing +end + +function rhs_split_2!(dq, q, semi::Semidiscretization, t) + @unpack mesh, equations, initial_condition, boundary_conditions, solver, source_terms, cache = semi + + @trixi_timeit timer() "rhs_split_2!" rhs_split_2!(dq, q, t, mesh, equations, + initial_condition, + boundary_conditions, source_terms, + solver, + cache) + return nothing +end + function compute_coefficients(func, t, semi::Semidiscretization) @unpack mesh, equations, solver = semi q = allocate_coefficients(mesh_equations_solver(semi)...) @@ -214,16 +236,21 @@ function check_bathymetry(equations::AbstractShallowWaterEquations, q0) end """ - semidiscretize(semi::Semidiscretization, tspan) + semidiscretize(semi::Semidiscretization, tspan; no_splitform = true) Wrap the semidiscretization `semi` as an ODE problem in the time interval `tspan` that can be passed to `solve` from the [SciML ecosystem](https://diffeq.sciml.ai/latest/). """ -function semidiscretize(semi::Semidiscretization, tspan) +function semidiscretize(semi::Semidiscretization, tspan; no_splitform = true) q0 = compute_coefficients(semi.initial_condition, first(tspan), semi) check_bathymetry(semi.equations, q0) iip = true # is-inplace, i.e., we modify a vector when calling rhs! - return ODEProblem{iip}(rhs!, q0, tspan, semi) + if no_splitform + ode = ODEProblem{iip}(rhs!, q0, tspan, semi) + else + ode = ODEProblem{iip}(SplitFunction(rhs_split_1!, rhs_split_2!), q0, tspan, semi) + end + return ode end """ @@ -241,7 +268,8 @@ of the semidiscretization `semi` at the state `q0`. function jacobian(semi::Semidiscretization; t = 0.0, q0 = compute_coefficients(semi.initial_condition, t, semi)) - J = ForwardDiff.jacobian(similar(q0), q0) do dq, q + @unpack tmp_partitioned = semi.cache + J = ForwardDiff.jacobian(tmp_partitioned, q0) do dq, q DispersiveShallowWater.rhs!(dq, q, semi, t) end return J From 62f42fec7e816b95613f2c83b4dfabec4f4d9381 Mon Sep 17 00:00:00 2001 From: Collin Wittenstein Date: Fri, 23 May 2025 15:44:55 +0200 Subject: [PATCH 02/11] deleted rhs using split --- src/equations/kdv_1d.jl | 27 +-------------------------- 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/src/equations/kdv_1d.jl b/src/equations/kdv_1d.jl index eb15142f5..11cd2c624 100644 --- a/src/equations/kdv_1d.jl +++ b/src/equations/kdv_1d.jl @@ -158,7 +158,7 @@ function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, # set D1 for hyperbolic terms D1 = solver.D1 end - + # deta = 1 / 6 sqrt(g * D) D^2 eta_xxx @.. deta = -1 / 6 * c_0 * DD * tmp_1 end @@ -182,31 +182,6 @@ function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, return nothing end -#= -function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, - ::BoundaryConditionPeriodic, source_terms, solver, cache) - # define locally bc to not use a global variable - bc = boundary_condition_periodic - (; tmp1) = cache - rhs_1!(dq, q, t, mesh, equations, initial_condition, - bc, source_terms, solver, cache) - - deta, = dq.x - tmp1 .= deta - - # rhs_split_2! needs to be defined to set "deta = ..." - # and not just "deta += ..." in order to work with SplitODEProblem - # This means one either has to store the result in a temporary variable - # or have a function rhs! does not call rhs_split_1! and rhs_split_2!. - # and currently it does not work with ForwardDiff - rhs_split_2!(dq, q, t, mesh, equations, initial_condition, - bc, source_terms, solver, cache) - - deta .+= tmp1 - return nothing -end -# =# - function rhs_split_1!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, ::BoundaryConditionPeriodic, source_terms, solver, cache) eta, = q.x From d9806af6c3c034d77829c7b138263a8d6defd99e Mon Sep 17 00:00:00 2001 From: Collin Wittenstein Date: Fri, 23 May 2025 15:54:06 +0200 Subject: [PATCH 03/11] added tests --- test/test_kdv_1d.jl | 11 +++++++++++ test/test_util.jl | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/test/test_kdv_1d.jl b/test/test_kdv_1d.jl index e0c2e4039..866eaa52f 100644 --- a/test/test_kdv_1d.jl +++ b/test/test_kdv_1d.jl @@ -59,3 +59,14 @@ end @test_allocations(semi, sol, allocs=5_000) end + +@testitem "kdv_1d_IMEX" setup=[Setup, KdVEquation1D] begin + @test_trixi_include(joinpath(EXAMPLES_DIR, "kdv_1d_IMEX.jl"), + tspan=(0.0, 5.0), + l2=[0.0007835879713461127], + linf=[0.0005961613764722262], + cons_error=[4.440892098500626e-16], + change_waterheight=-4.440892098500626e-16) + + @test_allocations_splitform(semi, sol, allocs=5_000) +end diff --git a/test/test_util.jl b/test/test_util.jl index 5772aa027..50b7db7a7 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -176,6 +176,24 @@ macro test_allocations(semi, sol, allocs) end end +""" + @test_allocations_splitform(semi, sol, allocs) + +Test that the memory allocations of `DispersiveShallowWater.rhs_split_1!` +and `DispersiveShallowWater.rhs_split_2!` are below `allocs` +(e.g., from type instabilities). +""" +macro test_allocations_splitform(semi, sol, allocs) + quote + t = $sol.t[end] + q = $sol.u[end] + dq = similar(q) + a1 = @allocated DispersiveShallowWater.rhs_split_1!(dq, q, $semi, t) + a2 = @allocated DispersiveShallowWater.rhs_split_2!(dq, q, $semi, t) + @test (a1 + a2) < $allocs + end +end + macro test_nowarn_mod(expr, additional_ignore_content = []) quote add_to_additional_ignore_content = [ From a3c1c50ef215161fb1874fe1779a340060a5ea9f Mon Sep 17 00:00:00 2001 From: Collin Wittenstein <126870995+cwittens@users.noreply.github.com> Date: Fri, 23 May 2025 18:57:36 +0200 Subject: [PATCH 04/11] Apply suggestions from code review Co-authored-by: Joshua Lampert <51029046+JoshuaLampert@users.noreply.github.com> --- src/equations/kdv_1d.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/equations/kdv_1d.jl b/src/equations/kdv_1d.jl index 11cd2c624..26a6afba7 100644 --- a/src/equations/kdv_1d.jl +++ b/src/equations/kdv_1d.jl @@ -182,7 +182,7 @@ function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, return nothing end -function rhs_split_1!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, +function rhs_split_stiff!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, ::BoundaryConditionPeriodic, source_terms, solver, cache) eta, = q.x deta, = dq.x @@ -213,7 +213,7 @@ function rhs_split_1!(dq, q, t, mesh, equations::KdVEquation1D, initial_conditio return nothing end -function rhs_split_2!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, +function rhs_split_nonstiff!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, ::BoundaryConditionPeriodic, source_terms, solver, cache) eta, = q.x deta, = dq.x From 8ec469f234f65f31499424e66a2fb78fb187813f Mon Sep 17 00:00:00 2001 From: Collin Wittenstein Date: Fri, 23 May 2025 19:08:01 +0200 Subject: [PATCH 05/11] changed alle rhs_split names --- src/semidiscretization.jl | 30 +++++++++++++++++------------- test/test_util.jl | 8 ++++---- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/semidiscretization.jl b/src/semidiscretization.jl index 7bf15b1d2..b9fdc2e2e 100644 --- a/src/semidiscretization.jl +++ b/src/semidiscretization.jl @@ -186,25 +186,28 @@ function rhs!(dq, q, semi::Semidiscretization, t) return nothing end -function rhs_split_1!(dq, q, semi::Semidiscretization, t) +function rhs_split_stiff!(dq, q, semi::Semidiscretization, t) @unpack mesh, equations, initial_condition, boundary_conditions, solver, source_terms, cache = semi - @trixi_timeit timer() "rhs_split_1!" rhs_split_1!(dq, q, t, mesh, equations, - initial_condition, - boundary_conditions, source_terms, - solver, - cache) + @trixi_timeit timer() "rhs_split_stiff!" rhs_split_stiff!(dq, q, t, mesh, equations, + initial_condition, + boundary_conditions, + source_terms, + solver, + cache) return nothing end -function rhs_split_2!(dq, q, semi::Semidiscretization, t) +function rhs_split_nonstiff!(dq, q, semi::Semidiscretization, t) @unpack mesh, equations, initial_condition, boundary_conditions, solver, source_terms, cache = semi - @trixi_timeit timer() "rhs_split_2!" rhs_split_2!(dq, q, t, mesh, equations, - initial_condition, - boundary_conditions, source_terms, - solver, - cache) + @trixi_timeit timer() "rhs_split_nonstiff!" rhs_split_nonstiff!(dq, q, t, mesh, + equations, + initial_condition, + boundary_conditions, + source_terms, + solver, + cache) return nothing end @@ -248,7 +251,8 @@ function semidiscretize(semi::Semidiscretization, tspan; no_splitform = true) if no_splitform ode = ODEProblem{iip}(rhs!, q0, tspan, semi) else - ode = ODEProblem{iip}(SplitFunction(rhs_split_1!, rhs_split_2!), q0, tspan, semi) + ode = ODEProblem{iip}(SplitFunction(rhs_split_stiff!, rhs_split_nonstiff!), q0, + tspan, semi) end return ode end diff --git a/test/test_util.jl b/test/test_util.jl index 50b7db7a7..2abd17428 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -179,8 +179,8 @@ end """ @test_allocations_splitform(semi, sol, allocs) -Test that the memory allocations of `DispersiveShallowWater.rhs_split_1!` -and `DispersiveShallowWater.rhs_split_2!` are below `allocs` +Test that the memory allocations of `DispersiveShallowWater.rhs_split_stiff!` +and `DispersiveShallowWater.rhs_split_nonstiff!` are below `allocs` (e.g., from type instabilities). """ macro test_allocations_splitform(semi, sol, allocs) @@ -188,8 +188,8 @@ macro test_allocations_splitform(semi, sol, allocs) t = $sol.t[end] q = $sol.u[end] dq = similar(q) - a1 = @allocated DispersiveShallowWater.rhs_split_1!(dq, q, $semi, t) - a2 = @allocated DispersiveShallowWater.rhs_split_2!(dq, q, $semi, t) + a1 = @allocated DispersiveShallowWater.rhs_split_stiff!(dq, q, $semi, t) + a2 = @allocated DispersiveShallowWater.rhs_split_nonstiff!(dq, q, $semi, t) @test (a1 + a2) < $allocs end end From a33f088d029b742e647c270ad8c09827c7d285f1 Mon Sep 17 00:00:00 2001 From: Collin Wittenstein Date: Sat, 24 May 2025 00:01:40 +0200 Subject: [PATCH 06/11] quite a few changes --- examples/kdv_1d/kdv_1d_IMEX.jl | 10 +- examples/kdv_1d/kdv_1d_basic.jl | 2 +- examples/kdv_1d/kdv_1d_fourier.jl | 2 +- examples/kdv_1d/kdv_1d_implicit.jl | 2 +- examples/kdv_1d/kdv_1d_manufactured.jl | 123 ++++++++++++++++++++++- examples/kdv_1d/kdv_1d_narrow_stencil.jl | 2 +- src/DispersiveShallowWater.jl | 2 +- src/equations/equations.jl | 15 +++ src/equations/kdv_1d.jl | 70 +++++++++++++ src/semidiscretization.jl | 65 ++++++++---- test/Project.toml | 2 + test/test_kdv_1d.jl | 8 +- test/test_unit.jl | 1 + 13 files changed, 268 insertions(+), 36 deletions(-) diff --git a/examples/kdv_1d/kdv_1d_IMEX.jl b/examples/kdv_1d/kdv_1d_IMEX.jl index fb792ed31..24c251c77 100644 --- a/examples/kdv_1d/kdv_1d_IMEX.jl +++ b/examples/kdv_1d/kdv_1d_IMEX.jl @@ -1,4 +1,4 @@ -using OrdinaryDiffEqTsit5 +using OrdinaryDiffEqSDIRK using DispersiveShallowWater using SummationByPartsOperators: upwind_operators, periodic_derivative_operator @@ -29,7 +29,7 @@ semi = Semidiscretization(mesh, equations, initial_condition, solver, boundary_conditions = boundary_conditions) tspan = (0.0, 5.0) -ode = semidiscretize(semi, tspan, no_splitform = false) +ode = semidiscretize(semi, tspan) summary_callback = SummaryCallback() analysis_callback = AnalysisCallback(semi; interval = 100, @@ -39,8 +39,8 @@ analysis_callback = AnalysisCallback(semi; interval = 100, callbacks = CallbackSet(analysis_callback, summary_callback) saveat = range(tspan..., length = 100) -# alg = Rodas5() # not working because of https://github.com/SciML/OrdinaryDiffEq.jl/issues/2719 -# alg = KenCarp4() # would need to add OrdinaryDiffEqSDIRK - can to that, but I dont want to bloat DSW.jl with package s -alg = Tsit5() + +alg = KenCarp4() # use an IMEX method sol = solve(ode, alg, abstol = 1e-7, reltol = 1e-7, save_everystep = false, callback = callbacks, saveat = saveat) + diff --git a/examples/kdv_1d/kdv_1d_basic.jl b/examples/kdv_1d/kdv_1d_basic.jl index 6b7de1360..344775be3 100644 --- a/examples/kdv_1d/kdv_1d_basic.jl +++ b/examples/kdv_1d/kdv_1d_basic.jl @@ -29,7 +29,7 @@ semi = Semidiscretization(mesh, equations, initial_condition, solver, boundary_conditions = boundary_conditions) tspan = (0.0, 5.0) -ode = semidiscretize(semi, tspan) +ode = semidiscretize(semi, tspan, split_ode = Val{false}()) # no IMEX for now summary_callback = SummaryCallback() analysis_callback = AnalysisCallback(semi; interval = 100, diff --git a/examples/kdv_1d/kdv_1d_fourier.jl b/examples/kdv_1d/kdv_1d_fourier.jl index 25d490158..0b1cf64ac 100644 --- a/examples/kdv_1d/kdv_1d_fourier.jl +++ b/examples/kdv_1d/kdv_1d_fourier.jl @@ -25,7 +25,7 @@ semi = Semidiscretization(mesh, equations, initial_condition, solver, boundary_conditions = boundary_conditions) tspan = (0.0, 5.0) -ode = semidiscretize(semi, tspan) +ode = semidiscretize(semi, tspan, split_ode = Val{false}()) # no IMEX for now summary_callback = SummaryCallback() analysis_callback = AnalysisCallback(semi; interval = 100, diff --git a/examples/kdv_1d/kdv_1d_implicit.jl b/examples/kdv_1d/kdv_1d_implicit.jl index ee8c0fc2b..9765f5be6 100644 --- a/examples/kdv_1d/kdv_1d_implicit.jl +++ b/examples/kdv_1d/kdv_1d_implicit.jl @@ -29,7 +29,7 @@ semi = Semidiscretization(mesh, equations, initial_condition, solver, boundary_conditions = boundary_conditions) tspan = (0.0, 5.0) -ode = semidiscretize(semi, tspan) +ode = semidiscretize(semi, tspan, split_ode = Val{false}()) summary_callback = SummaryCallback() analysis_callback = AnalysisCallback(semi; interval = 100, diff --git a/examples/kdv_1d/kdv_1d_manufactured.jl b/examples/kdv_1d/kdv_1d_manufactured.jl index 81a8ecffa..7474c80dc 100644 --- a/examples/kdv_1d/kdv_1d_manufactured.jl +++ b/examples/kdv_1d/kdv_1d_manufactured.jl @@ -32,7 +32,7 @@ semi = Semidiscretization(mesh, equations, initial_condition, solver, source_terms = source_terms) tspan = (0.0, 1.0) -ode = semidiscretize(semi, tspan) +ode = semidiscretize(semi, tspan, split_ode = Val{false}()) summary_callback = SummaryCallback() analysis_callback = AnalysisCallback(semi; interval = 100, @@ -46,3 +46,124 @@ saveat = range(tspan..., length = 100) alg = Rodas5() sol = solve(ode, alg, abstol = 1e-12, reltol = 1e-12, save_everystep = false, callback = callbacks, saveat = saveat) + +""" +using alg = KenCarp4() I get the following Benchmarks: +For some reason IMEX seems to perform worse - doing more steps +no IMEX: (split_ode = Val{false}()) +────────────────────────────────────────────────────────────────────────────────────── + DispersiveSWE Time Allocations + ─────────────────────── ──────────────────────── + Tot / % measured: 611ms / 34.2% 40.0MiB / 10.0% + +Section ncalls time %tot avg alloc %tot avg +────────────────────────────────────────────────────────────────────────────────────── +rhs! 6.62k 196ms 93.6% 29.6μs 3.94MiB 98.4% 624B + source terms 6.62k 108ms 51.5% 16.3μs 3.94MiB 98.4% 624B + third-order derivatives 6.62k 44.1ms 21.1% 6.66μs 0.00B 0.0% 0.00B + hyperbolic 6.62k 41.1ms 19.7% 6.22μs 0.00B 0.0% 0.00B + ~rhs!~ 6.62k 2.97ms 1.4% 450ns 1.25KiB 0.0% 0.19B +analyze solution 3 13.3ms 6.4% 4.45ms 64.8KiB 1.6% 21.6KiB +────────────────────────────────────────────────────────────────────────────────────── + +IMEX with sources in nonstiff part: +────────────────────────────────────────────────────────────────────────────────────── + DispersiveSWE Time Allocations + ─────────────────────── ──────────────────────── + Tot / % measured: 2.85s / 12.3% 98.6MiB / 5.7% + +Section ncalls time %tot avg alloc %tot avg +────────────────────────────────────────────────────────────────────────────────────── +rhs_split_nonstiff! 8.93k 169ms 48.3% 18.9μs 5.31MiB 93.8% 624B + source terms 8.93k 160ms 45.6% 17.9μs 5.31MiB 93.8% 624B + hyperbolic 8.93k 5.67ms 1.6% 635ns 0.00B 0.0% 0.00B + ~rhs_split_nonstiff!~ 8.93k 3.44ms 1.0% 386ns 976B 0.0% 0.11B +rhs_split_stiff! 45.4k 103ms 29.4% 2.27μs 672B 0.0% 0.01B + third-order derivatives 45.4k 95.9ms 27.4% 2.11μs 0.00B 0.0% 0.00B + ~rhs_split_stiff!~ 45.4k 7.07ms 2.0% 156ns 672B 0.0% 0.01B +analyze solution 15 78.2ms 22.3% 5.21ms 360KiB 6.2% 24.0KiB +────────────────────────────────────────────────────────────────────────────────────── + +IMEX with sources in stiff part: +────────────────────────────────────────────────────────────────────────────────────── + DispersiveSWE Time Allocations + ─────────────────────── ──────────────────────── + Tot / % measured: 4.89s / 28.2% 88.7MiB / 49.0% + +Section ncalls time %tot avg alloc %tot avg +────────────────────────────────────────────────────────────────────────────────────── +rhs_split_stiff! 72.2k 1.27s 92.0% 17.6μs 43.0MiB 98.9% 624B + source terms 72.2k 1.17s 85.2% 16.3μs 43.0MiB 98.9% 624B + third-order derivatives 72.2k 70.8ms 5.1% 981ns 0.00B 0.0% 0.00B + ~rhs_split_stiff!~ 72.2k 22.1ms 1.6% 306ns 976B 0.0% 0.01B +analyze solution 23 102ms 7.4% 4.42ms 501KiB 1.1% 21.8KiB +rhs_split_nonstiff! 13.3k 8.86ms 0.6% 665ns 672B 0.0% 0.05B + hyperbolic 13.3k 6.82ms 0.5% 512ns 0.00B 0.0% 0.00B + ~rhs_split_nonstiff!~ 13.3k 2.03ms 0.1% 153ns 672B 0.0% 0.05B +────────────────────────────────────────────────────────────────────────────────────── + + + + + + +No with fixed time step dt = 1e-2: + + + +no IMEX +────────────────────────────────────────────────────────────────────────────────────── + DispersiveSWE Time Allocations + ─────────────────────── ──────────────────────── + Tot / % measured: 6.33s / 57.1% 0.99GiB / 1.3% + +Section ncalls time %tot avg alloc %tot avg +────────────────────────────────────────────────────────────────────────────────────── +rhs! 22.0k 3.46s 96.0% 157μs 13.1MiB 99.8% 624B + third-order derivatives 22.0k 1.61s 44.6% 73.1μs 0.00B 0.0% 0.00B + hyperbolic 22.0k 1.49s 41.2% 67.5μs 0.00B 0.0% 0.00B + source terms 22.0k 355ms 9.8% 16.1μs 13.1MiB 99.8% 624B + ~rhs!~ 22.0k 13.5ms 0.4% 615ns 1.25KiB 0.0% 0.06B +analyze solution 1 146ms 4.0% 146ms 21.9KiB 0.2% 21.9KiB +────────────────────────────────────────────────────────────────────────────────────── + +IMEX with sources in nonstiff part: +────────────────────────────────────────────────────────────────────────────────────── + DispersiveSWE Time Allocations + ─────────────────────── ──────────────────────── + Tot / % measured: 4.34s / 39.4% 0.98GiB / 0.0% + +Section ncalls time %tot avg alloc %tot avg +────────────────────────────────────────────────────────────────────────────────────── +rhs_split_stiff! 22.2k 1.70s 99.0% 76.3μs 672B 0.2% 0.03B + third-order derivatives 22.2k 1.69s 98.7% 76.1μs 0.00B 0.0% 0.00B + ~rhs_split_stiff!~ 22.2k 5.55ms 0.3% 250ns 672B 0.2% 0.03B +rhs_split_nonstiff! 601 15.1ms 0.9% 25.0μs 367KiB 89.4% 626B + source terms 601 12.9ms 0.8% 21.5μs 366KiB 89.2% 624B + hyperbolic 601 1.41ms 0.1% 2.35μs 0.00B 0.0% 0.00B + ~rhs_split_nonstiff!~ 601 713μs 0.0% 1.19μs 976B 0.2% 1.62B +analyze solution 1 1.86ms 0.1% 1.86ms 42.7KiB 10.4% 42.7KiB +────────────────────────────────────────────────────────────────────────────────────── + +IMEX with sources in stiff part: +────────────────────────────────────────────────────────────────────────────────────── + DispersiveSWE Time Allocations + ─────────────────────── ──────────────────────── + Tot / % measured: 4.29s / 48.3% 0.99GiB / 1.3% + +Section ncalls time %tot avg alloc %tot avg +────────────────────────────────────────────────────────────────────────────────────── +rhs_split_stiff! 22.2k 2.06s 99.7% 92.9μs 13.2MiB 99.8% 624B + third-order derivatives 22.2k 1.70s 82.1% 76.4μs 0.00B 0.0% 0.00B + source terms 22.2k 356ms 17.2% 16.0μs 13.2MiB 99.8% 624B + ~rhs_split_stiff!~ 22.2k 8.90ms 0.4% 401ns 976B 0.0% 0.04B +analyze solution 1 3.62ms 0.2% 3.62ms 21.9KiB 0.2% 21.9KiB +rhs_split_nonstiff! 601 1.61ms 0.1% 2.68μs 672B 0.0% 1.12B + hyperbolic 601 1.16ms 0.1% 1.93μs 0.00B 0.0% 0.00B + ~rhs_split_nonstiff!~ 601 448μs 0.0% 746ns 672B 0.0% 1.12B +────────────────────────────────────────────────────────────────────────────────────── + + +Also for dt = 1e-1 the solution for IMEX already looks bad(ish), +will for split_ode = Val{false}() it looks still good. +""" \ No newline at end of file diff --git a/examples/kdv_1d/kdv_1d_narrow_stencil.jl b/examples/kdv_1d/kdv_1d_narrow_stencil.jl index e5aae1c87..756f8482a 100644 --- a/examples/kdv_1d/kdv_1d_narrow_stencil.jl +++ b/examples/kdv_1d/kdv_1d_narrow_stencil.jl @@ -22,7 +22,7 @@ semi = Semidiscretization(mesh, equations, initial_condition, solver, boundary_conditions = boundary_conditions) tspan = (0.0, 5.0) -ode = semidiscretize(semi, tspan) +ode = semidiscretize(semi, tspan, split_ode = Val{false}()) # no IMEX for now summary_callback = SummaryCallback() analysis_callback = AnalysisCallback(semi; interval = 100, diff --git a/src/DispersiveShallowWater.jl b/src/DispersiveShallowWater.jl index fcddc10f2..03b7df88a 100644 --- a/src/DispersiveShallowWater.jl +++ b/src/DispersiveShallowWater.jl @@ -74,7 +74,7 @@ export LinearDispersionRelation, EulerEquations1D, wave_speed export prim2prim, prim2cons, cons2prim, prim2phys, waterheight_total, waterheight, velocity, momentum, discharge, - gravity, + gravity, have_stiff_terms bathymetry, still_water_surface, energy_total, entropy, lake_at_rest_error, energy_total_modified, entropy_modified, diff --git a/src/equations/equations.jl b/src/equations/equations.jl index 5c3af1c88..324df8287 100644 --- a/src/equations/equations.jl +++ b/src/equations/equations.jl @@ -265,6 +265,21 @@ Return the gravitational acceleration ``g`` for a given set of `equations`. return equations.gravity end +""" + DispersiveShallowWater.have_stiff_terms(equations) + +Returns `Val{true}()` if the equations have stiff terms that benefit from +implicit time integration methods and `Val{false}()` otherwise (default). + +This trait is used to determine whether to create a `SplitFunction` in +[`semidiscretize`](@ref) for IMEX time integration methods. + +!!! note "Implementation details" + This function is used for internal dispatch to determine the appropriate + ODE problem formulation. +""" +have_stiff_terms(::AbstractEquations) = Val{false}() + """ energy_total(q, equations) diff --git a/src/equations/kdv_1d.jl b/src/equations/kdv_1d.jl index 26a6afba7..520e61df0 100644 --- a/src/equations/kdv_1d.jl +++ b/src/equations/kdv_1d.jl @@ -53,6 +53,9 @@ function KdVEquation1D(; gravity, D = 1.0, eta0 = 0.0) KdVEquation1D(gravity, D, eta0) end +# KdV equations have stiff third-order derivative terms that benefit IMEX methods +have_stiff_terms(::KdVEquation1D) = Val{true}() + """ initial_condition_convergence_test(x, t, equations::KdVEquation1D, mesh) @@ -130,6 +133,7 @@ function create_cache(mesh, equations::KdVEquation1D, return cache end +""" function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, ::BoundaryConditionPeriodic, source_terms, solver, cache) eta, = q.x @@ -250,3 +254,69 @@ function rhs_split_nonstiff!(dq, q, t, mesh, equations::KdVEquation1D, initial_c return nothing end +""" + +function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, + ::BoundaryConditionPeriodic, source_terms, solver, cache, + mode::Symbol = :full) + eta, = q.x + deta, = dq.x + + (; c_0, c_1, DD) = cache + tmp_1 = get_tmp(cache.tmp_1, eta) + tmp_2 = get_tmp(cache.tmp_2, eta) + # In order to use automatic differentiation, we need to extract + # the storage vectors using `get_tmp` from PreallocationTools.jl + # so they can also hold dual numbers when needed. + if solver.D1 isa PeriodicUpwindOperators && isnothing(solver.D3) + D1 = solver.D1.central + else + D1 = solver.D1 + end + + # Initialize deta based on mode + if mode == :full || mode == :stiff + @trixi_timeit timer() "third-order derivatives" begin + if solver.D1 isa PeriodicUpwindOperators && isnothing(solver.D3) + mul!(tmp_1, solver.D1.minus, eta) + mul!(tmp_2, solver.D1.central, tmp_1) + mul!(tmp_1, solver.D1.plus, tmp_2) + else + mul!(tmp_1, solver.D3, eta) + end + + # add stiff part + # deta = 1 / 6 sqrt(g * D) D^2 eta_xxx + @.. deta = -1/6 * c_0 * DD * tmp_1 + + end + + end + + if mode == :full || mode == :nonstiff + @trixi_timeit timer() "hyperbolic" begin + # eta2 = eta^2 + @.. tmp_1 = eta^2 + + # eta2_x = D1 * eta2 + mul!(tmp_2, D1, tmp_1) + + # eta_x = D1 * eta + mul!(tmp_1, D1, eta) + + # Set or add non-stiff part + # deta -= sqrt(g * D) * eta_x + 1 / 2 * sqrt(g / D) * (eta * eta_x + eta2_x) + if mode == :nonstiff + @.. deta = -(c_0 * tmp_1 + c_1 * (eta * tmp_1 + tmp_2)) + else # mode == :full + @.. deta -= (c_0 * tmp_1 + c_1 * (eta * tmp_1 + tmp_2)) # Add to existing stiff part + end + end + + @trixi_timeit timer() "source terms" calc_sources!(dq, q, t, source_terms, equations, solver) + + end + + return nothing +end + diff --git a/src/semidiscretization.jl b/src/semidiscretization.jl index b9fdc2e2e..671d34d53 100644 --- a/src/semidiscretization.jl +++ b/src/semidiscretization.jl @@ -189,25 +189,20 @@ end function rhs_split_stiff!(dq, q, semi::Semidiscretization, t) @unpack mesh, equations, initial_condition, boundary_conditions, solver, source_terms, cache = semi - @trixi_timeit timer() "rhs_split_stiff!" rhs_split_stiff!(dq, q, t, mesh, equations, - initial_condition, - boundary_conditions, - source_terms, - solver, - cache) + @trixi_timeit timer() "rhs_split_stiff!" rhs!(dq, q, t, mesh, equations, + initial_condition, + boundary_conditions, source_terms, solver, + cache, :stiff) return nothing end function rhs_split_nonstiff!(dq, q, semi::Semidiscretization, t) @unpack mesh, equations, initial_condition, boundary_conditions, solver, source_terms, cache = semi - @trixi_timeit timer() "rhs_split_nonstiff!" rhs_split_nonstiff!(dq, q, t, mesh, - equations, - initial_condition, - boundary_conditions, - source_terms, - solver, - cache) + @trixi_timeit timer() "rhs_split_nonstiff!" rhs!(dq, q, t, mesh, equations, + initial_condition, + boundary_conditions, source_terms, + solver, cache, :nonstiff) return nothing end @@ -239,22 +234,50 @@ function check_bathymetry(equations::AbstractShallowWaterEquations, q0) end """ - semidiscretize(semi::Semidiscretization, tspan; no_splitform = true) + semidiscretize(semi::Semidiscretization, tspan; split_ode = have_stiff_terms(semi.equations)) Wrap the semidiscretization `semi` as an ODE problem in the time interval `tspan` that can be passed to `solve` from the [SciML ecosystem](https://diffeq.sciml.ai/latest/). + +If `split_ode` is `Val{false}()`, a regular `ODEFunction` is created. +If `split_ode` is `Val{true}()`, a `SplitFunction` is created for IMEX time integration if available. +By default, `split_ode` is determined by the [`DispersiveShallowWater.have_stiff_terms`](@ref) trait. """ -function semidiscretize(semi::Semidiscretization, tspan; no_splitform = true) +function semidiscretize(semi::Semidiscretization, tspan; + split_ode = have_stiff_terms(semi.equations)) q0 = compute_coefficients(semi.initial_condition, first(tspan), semi) check_bathymetry(semi.equations, q0) iip = true # is-inplace, i.e., we modify a vector when calling rhs! - if no_splitform - ode = ODEProblem{iip}(rhs!, q0, tspan, semi) - else - ode = ODEProblem{iip}(SplitFunction(rhs_split_stiff!, rhs_split_nonstiff!), q0, - tspan, semi) + return _semidiscretize_ode(split_ode, q0, tspan, semi, iip) +end + +# Type-stable dispatch based on split_ode trait +function _semidiscretize_ode(::Val{false}, q0, tspan, semi, iip) + return ODEProblem{iip}(rhs!, q0, tspan, semi) +end + +function _semidiscretize_ode(::Val{true}, q0, tspan, semi, iip) + _check_split_rhs_implementation(semi) + return ODEProblem{iip}(SplitFunction(rhs_split_stiff!, rhs_split_nonstiff!), q0, tspan, + semi) +end + +function _check_split_rhs_implementation(semi) + @unpack mesh, equations, initial_condition, boundary_conditions, solver, source_terms, cache = semi + + equation_name = get_name(equations) + args = (nothing, nothing, nothing, mesh, equations, initial_condition, boundary_conditions, source_terms, solver, cache) + + # # Check if methods are applicable + if !applicable(rhs!, args..., :stiff) + throw(ArgumentError("Split RHS method with :stiff argument not implemented for $equation_name.")) end - return ode + + if !applicable(rhs!, args..., :nonstiff) + throw(ArgumentError("Split RHS method with :nonstiff argument not implemented for $equation_name.")) + end + + return nothing end """ diff --git a/test/Project.toml b/test/Project.toml index d3ee423c8..06b2600b8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" OrdinaryDiffEqLowStorageRK = "b0944070-b475-4768-8dec-fb6eb410534d" OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce" OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" +OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SummationByPartsOperators = "9f78cca6-572e-554e-b819-917d2f1cf240" @@ -22,6 +23,7 @@ JET = "0.9.9" OrdinaryDiffEqLowStorageRK = "1.1" OrdinaryDiffEqRosenbrock = "1.3" OrdinaryDiffEqTsit5 = "1.1" +OrdinaryDiffEqSDIRK = "1.1" Plots = "1.25" SparseArrays = "1" SummationByPartsOperators = "0.5.63" diff --git a/test/test_kdv_1d.jl b/test/test_kdv_1d.jl index 866eaa52f..f0a9b9239 100644 --- a/test/test_kdv_1d.jl +++ b/test/test_kdv_1d.jl @@ -63,10 +63,10 @@ end @testitem "kdv_1d_IMEX" setup=[Setup, KdVEquation1D] begin @test_trixi_include(joinpath(EXAMPLES_DIR, "kdv_1d_IMEX.jl"), tspan=(0.0, 5.0), - l2=[0.0007835879713461127], - linf=[0.0005961613764722262], - cons_error=[4.440892098500626e-16], - change_waterheight=-4.440892098500626e-16) + l2=[0.004952174509850488], + linf=[0.003962890861977875], + cons_error=[2.220446049250313e-15], + change_waterheight=-2.220446049250313e-15) @test_allocations_splitform(semi, sol, allocs=5_000) end diff --git a/test/test_unit.jl b/test/test_unit.jl index 6d9d024a0..119c7b2a8 100644 --- a/test/test_unit.jl +++ b/test/test_unit.jl @@ -83,6 +83,7 @@ end solver = Solver(mesh, 4) semi_flat = Semidiscretization(mesh, equations_flat, initial_condition, solver) @test_throws ArgumentError semidiscretize(semi_flat, (0.0, 1.0)) + @test_throws ArgumentError semidiscretize(semi, (0.0, 1.0), split_ode = Val{true}()) end @testitem "Boundary conditions" setup=[Setup] begin From bc57c2b517a56dae1eee23de996b7c784cb8217a Mon Sep 17 00:00:00 2001 From: Collin Wittenstein Date: Sat, 24 May 2025 00:05:00 +0200 Subject: [PATCH 07/11] formatting --- examples/kdv_1d/kdv_1d_IMEX.jl | 2 -- examples/kdv_1d/kdv_1d_manufactured.jl | 2 +- src/DispersiveShallowWater.jl | 10 +++---- src/equations/kdv_1d.jl | 39 ++++++++++++-------------- src/semidiscretization.jl | 5 ++-- 5 files changed, 27 insertions(+), 31 deletions(-) diff --git a/examples/kdv_1d/kdv_1d_IMEX.jl b/examples/kdv_1d/kdv_1d_IMEX.jl index 24c251c77..6ca8c8c79 100644 --- a/examples/kdv_1d/kdv_1d_IMEX.jl +++ b/examples/kdv_1d/kdv_1d_IMEX.jl @@ -39,8 +39,6 @@ analysis_callback = AnalysisCallback(semi; interval = 100, callbacks = CallbackSet(analysis_callback, summary_callback) saveat = range(tspan..., length = 100) - alg = KenCarp4() # use an IMEX method sol = solve(ode, alg, abstol = 1e-7, reltol = 1e-7, save_everystep = false, callback = callbacks, saveat = saveat) - diff --git a/examples/kdv_1d/kdv_1d_manufactured.jl b/examples/kdv_1d/kdv_1d_manufactured.jl index 7474c80dc..9b6485944 100644 --- a/examples/kdv_1d/kdv_1d_manufactured.jl +++ b/examples/kdv_1d/kdv_1d_manufactured.jl @@ -166,4 +166,4 @@ rhs_split_nonstiff! 601 1.61ms 0.1% 2.68μs 672B 0.0% Also for dt = 1e-1 the solution for IMEX already looks bad(ish), will for split_ode = Val{false}() it looks still good. -""" \ No newline at end of file +""" diff --git a/src/DispersiveShallowWater.jl b/src/DispersiveShallowWater.jl index 03b7df88a..f038fba53 100644 --- a/src/DispersiveShallowWater.jl +++ b/src/DispersiveShallowWater.jl @@ -74,11 +74,11 @@ export LinearDispersionRelation, EulerEquations1D, wave_speed export prim2prim, prim2cons, cons2prim, prim2phys, waterheight_total, waterheight, velocity, momentum, discharge, - gravity, have_stiff_terms - bathymetry, still_water_surface, - energy_total, entropy, lake_at_rest_error, - energy_total_modified, entropy_modified, - hamiltonian + gravity, have_stiff_terms +bathymetry, still_water_surface, +energy_total, entropy, lake_at_rest_error, +energy_total_modified, entropy_modified, +hamiltonian export Mesh1D, xmin, xmax, nnodes diff --git a/src/equations/kdv_1d.jl b/src/equations/kdv_1d.jl index 520e61df0..94587edd9 100644 --- a/src/equations/kdv_1d.jl +++ b/src/equations/kdv_1d.jl @@ -287,36 +287,33 @@ function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, # add stiff part # deta = 1 / 6 sqrt(g * D) D^2 eta_xxx - @.. deta = -1/6 * c_0 * DD * tmp_1 - + @.. deta = -1 / 6 * c_0 * DD * tmp_1 end - end - + if mode == :full || mode == :nonstiff @trixi_timeit timer() "hyperbolic" begin - # eta2 = eta^2 - @.. tmp_1 = eta^2 - - # eta2_x = D1 * eta2 - mul!(tmp_2, D1, tmp_1) - - # eta_x = D1 * eta - mul!(tmp_1, D1, eta) - - # Set or add non-stiff part - # deta -= sqrt(g * D) * eta_x + 1 / 2 * sqrt(g / D) * (eta * eta_x + eta2_x) - if mode == :nonstiff - @.. deta = -(c_0 * tmp_1 + c_1 * (eta * tmp_1 + tmp_2)) + # eta2 = eta^2 + @.. tmp_1 = eta^2 + + # eta2_x = D1 * eta2 + mul!(tmp_2, D1, tmp_1) + + # eta_x = D1 * eta + mul!(tmp_1, D1, eta) + + # Set or add non-stiff part + # deta -= sqrt(g * D) * eta_x + 1 / 2 * sqrt(g / D) * (eta * eta_x + eta2_x) + if mode == :nonstiff + @.. deta = -(c_0 * tmp_1 + c_1 * (eta * tmp_1 + tmp_2)) else # mode == :full @.. deta -= (c_0 * tmp_1 + c_1 * (eta * tmp_1 + tmp_2)) # Add to existing stiff part end end - - @trixi_timeit timer() "source terms" calc_sources!(dq, q, t, source_terms, equations, solver) - + + @trixi_timeit timer() "source terms" calc_sources!(dq, q, t, source_terms, + equations, solver) end return nothing end - diff --git a/src/semidiscretization.jl b/src/semidiscretization.jl index 671d34d53..66efb5a0f 100644 --- a/src/semidiscretization.jl +++ b/src/semidiscretization.jl @@ -257,7 +257,7 @@ function _semidiscretize_ode(::Val{false}, q0, tspan, semi, iip) end function _semidiscretize_ode(::Val{true}, q0, tspan, semi, iip) - _check_split_rhs_implementation(semi) + _check_split_rhs_implementation(semi) return ODEProblem{iip}(SplitFunction(rhs_split_stiff!, rhs_split_nonstiff!), q0, tspan, semi) end @@ -266,7 +266,8 @@ function _check_split_rhs_implementation(semi) @unpack mesh, equations, initial_condition, boundary_conditions, solver, source_terms, cache = semi equation_name = get_name(equations) - args = (nothing, nothing, nothing, mesh, equations, initial_condition, boundary_conditions, source_terms, solver, cache) + args = (nothing, nothing, nothing, mesh, equations, initial_condition, + boundary_conditions, source_terms, solver, cache) # # Check if methods are applicable if !applicable(rhs!, args..., :stiff) From 062f4d993f215da8f8d385e6bb13c754c8c6c23a Mon Sep 17 00:00:00 2001 From: Collin Wittenstein Date: Sat, 24 May 2025 00:11:51 +0200 Subject: [PATCH 08/11] added missing "," in export list --- src/DispersiveShallowWater.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/DispersiveShallowWater.jl b/src/DispersiveShallowWater.jl index f038fba53..531683a53 100644 --- a/src/DispersiveShallowWater.jl +++ b/src/DispersiveShallowWater.jl @@ -74,11 +74,11 @@ export LinearDispersionRelation, EulerEquations1D, wave_speed export prim2prim, prim2cons, cons2prim, prim2phys, waterheight_total, waterheight, velocity, momentum, discharge, - gravity, have_stiff_terms -bathymetry, still_water_surface, -energy_total, entropy, lake_at_rest_error, -energy_total_modified, entropy_modified, -hamiltonian + gravity, have_stiff_terms, + bathymetry, still_water_surface, + energy_total, entropy, lake_at_rest_error, + energy_total_modified, entropy_modified, + hamiltonian export Mesh1D, xmin, xmax, nnodes From f1c24e99be29779c05650ce0dc46a27879bf0ed1 Mon Sep 17 00:00:00 2001 From: Collin Wittenstein <126870995+cwittens@users.noreply.github.com> Date: Sun, 25 May 2025 00:25:05 +0200 Subject: [PATCH 09/11] Apply suggestions from code review Co-authored-by: Joshua Lampert <51029046+JoshuaLampert@users.noreply.github.com> --- src/equations/kdv_1d.jl | 2 +- src/semidiscretization.jl | 2 +- test/test_kdv_1d.jl | 2 +- test/test_util.jl | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/equations/kdv_1d.jl b/src/equations/kdv_1d.jl index 94587edd9..49a0b6524 100644 --- a/src/equations/kdv_1d.jl +++ b/src/equations/kdv_1d.jl @@ -53,7 +53,7 @@ function KdVEquation1D(; gravity, D = 1.0, eta0 = 0.0) KdVEquation1D(gravity, D, eta0) end -# KdV equations have stiff third-order derivative terms that benefit IMEX methods +# KdV equations have stiff third-order derivative terms that benefit from IMEX methods have_stiff_terms(::KdVEquation1D) = Val{true}() """ diff --git a/src/semidiscretization.jl b/src/semidiscretization.jl index 66efb5a0f..877e7ebc3 100644 --- a/src/semidiscretization.jl +++ b/src/semidiscretization.jl @@ -269,7 +269,7 @@ function _check_split_rhs_implementation(semi) args = (nothing, nothing, nothing, mesh, equations, initial_condition, boundary_conditions, source_terms, solver, cache) - # # Check if methods are applicable + # Check if methods are applicable if !applicable(rhs!, args..., :stiff) throw(ArgumentError("Split RHS method with :stiff argument not implemented for $equation_name.")) end diff --git a/test/test_kdv_1d.jl b/test/test_kdv_1d.jl index f0a9b9239..a627c30ac 100644 --- a/test/test_kdv_1d.jl +++ b/test/test_kdv_1d.jl @@ -68,5 +68,5 @@ end cons_error=[2.220446049250313e-15], change_waterheight=-2.220446049250313e-15) - @test_allocations_splitform(semi, sol, allocs=5_000) + @test_allocations_split_ode(semi, sol, allocs=5_000) end diff --git a/test/test_util.jl b/test/test_util.jl index 2abd17428..110d04b3a 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -177,13 +177,13 @@ macro test_allocations(semi, sol, allocs) end """ - @test_allocations_splitform(semi, sol, allocs) + @test_allocations_split_ode(semi, sol, allocs) Test that the memory allocations of `DispersiveShallowWater.rhs_split_stiff!` and `DispersiveShallowWater.rhs_split_nonstiff!` are below `allocs` (e.g., from type instabilities). """ -macro test_allocations_splitform(semi, sol, allocs) +macro test_allocations_split_ode(semi, sol, allocs) quote t = $sol.t[end] q = $sol.u[end] From b95b2d278f7e1c2e287d91695d6ac0d031f19d18 Mon Sep 17 00:00:00 2001 From: Joshua Lampert Date: Thu, 27 Nov 2025 11:04:15 +0100 Subject: [PATCH 10/11] fix merge conflict --- src/equations/kdv_1d.jl | 8 +++++--- test/test_kdv_1d.jl | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/equations/kdv_1d.jl b/src/equations/kdv_1d.jl index b8da1d9a4..99cadd2a3 100644 --- a/src/equations/kdv_1d.jl +++ b/src/equations/kdv_1d.jl @@ -303,7 +303,7 @@ function rhs_split_stiff!(dq, q, t, mesh, equations::KdVEquation1D, initial_cond mul!(tmp_1, solver.D3, eta) end - # deta = 1 / 6 sqrt(g * D) D^2 eta_xxx + # deta = 1 / 6 sqrt(g * D) D^2 eta_xxx @.. deta = -1 / 6 * c_0 * DD * tmp_1 end @@ -339,7 +339,7 @@ function rhs_split_nonstiff!(dq, q, t, mesh, equations::KdVEquation1D, initial_c # eta_x = D1 * eta mul!(tmp_1, D1, eta) - # deta -= sqrt(g * D) * eta_x + 1 / 2 * sqrt(g / D) * (eta * eta_x + eta2_x) + # deta -= sqrt(g * D) * eta_x + 1 / 2 * sqrt(g / D) * (eta * eta_x + eta2_x) @.. deta = -(c_0 * tmp_1 + c_1 * (eta * tmp_1 + tmp_2)) end @@ -397,7 +397,7 @@ function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, mul!(tmp_1, D1, eta) # Set or add non-stiff part - # deta -= sqrt(g * D) * eta_x + 1 / 2 * sqrt(g / D) * (eta * eta_x + eta2_x) + # deta -= sqrt(g * D) * eta_x + 1 / 2 * sqrt(g / D) * (eta * eta_x + eta2_x) if mode == :nonstiff @.. deta = -(c_0 * tmp_1 + c_1 * (eta * tmp_1 + tmp_2)) else # mode == :full @@ -410,6 +410,8 @@ function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, end return nothing +end + """ energy_total(q, equations::KdVEquation1D) diff --git a/test/test_kdv_1d.jl b/test/test_kdv_1d.jl index b81a85ec0..1f39dc057 100644 --- a/test/test_kdv_1d.jl +++ b/test/test_kdv_1d.jl @@ -97,5 +97,5 @@ end cons_error=[2.220446049250313e-15], change_waterheight=-2.220446049250313e-15) - @test_allocations_split_ode(semi, sol, allocs=5_000) + @test_allocations(DispersiveShallowWater.rhs!, semi, sol, allocs=5_000) end From aa2edf7dccd57b3f22b6e633805c283eb3dc10df Mon Sep 17 00:00:00 2001 From: Joshua Lampert Date: Thu, 27 Nov 2025 11:11:10 +0100 Subject: [PATCH 11/11] use SplitODEProblem --- examples/kdv_1d/kdv_1d_IMEX.jl | 2 +- src/DispersiveShallowWater.jl | 3 ++- src/semidiscretization.jl | 12 ++++++------ 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/kdv_1d/kdv_1d_IMEX.jl b/examples/kdv_1d/kdv_1d_IMEX.jl index 6ca8c8c79..a9c3da4f7 100644 --- a/examples/kdv_1d/kdv_1d_IMEX.jl +++ b/examples/kdv_1d/kdv_1d_IMEX.jl @@ -3,7 +3,7 @@ using DispersiveShallowWater using SummationByPartsOperators: upwind_operators, periodic_derivative_operator ############################################################################### -# Semidiscretization of the KdV equation +# Semidiscretization of the KdV equation equations = KdVEquation1D(gravity = 9.81, D = 1.0) initial_condition = initial_condition_convergence_test diff --git a/src/DispersiveShallowWater.jl b/src/DispersiveShallowWater.jl index e05d47ce6..3d7a9e6fa 100644 --- a/src/DispersiveShallowWater.jl +++ b/src/DispersiveShallowWater.jl @@ -31,7 +31,8 @@ using RecursiveArrayTools: ArrayPartition using Reexport: @reexport using Roots: AlefeldPotraShi, find_zero -using SciMLBase: SciMLBase, DiscreteCallback, ODEProblem, ODESolution, SplitFunction +using SciMLBase: SciMLBase, DiscreteCallback, ODEProblem, ODESolution, SplitFunction, + SplitODEProblem import SciMLBase: u_modified! @reexport using StaticArrays: SVector diff --git a/src/semidiscretization.jl b/src/semidiscretization.jl index f9463ee57..e95ac5cd9 100644 --- a/src/semidiscretization.jl +++ b/src/semidiscretization.jl @@ -92,10 +92,10 @@ end """ check_solver(equations, solver, boundary_conditions) -Check that the `solver` is compatible with the given `equations` and +Check that the `solver` is compatible with the given `equations` and `boundary_conditions`. The default implementation performs no checks. -Specific equation types can override this method to validate that -required derivative operators are present (e.g., some equations +Specific equation types can override this method to validate that +required derivative operators are present (e.g., some equations require `D2` or `D3` to be non-`nothing`). Throws an `ArgumentError` if the solver is incompatible. @@ -279,15 +279,15 @@ function semidiscretize(semi::Semidiscretization, tspan; return _semidiscretize_ode(split_ode, q0, tspan, semi, iip) end -# Type-stable dispatch based on split_ode trait +# Type-stable dispatch based on split_ode trait function _semidiscretize_ode(::Val{false}, q0, tspan, semi, iip) return ODEProblem{iip}(rhs!, q0, tspan, semi) end function _semidiscretize_ode(::Val{true}, q0, tspan, semi, iip) _check_split_rhs_implementation(semi) - return ODEProblem{iip}(SplitFunction(rhs_split_stiff!, rhs_split_nonstiff!), q0, tspan, - semi) + return SplitODEProblem{iip}(SplitFunction(rhs_split_stiff!, rhs_split_nonstiff!), q0, + tspan, semi) end function _check_split_rhs_implementation(semi)