From 2aa615c9094916750ad5b48e95053e4e7abcb222 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 23 Feb 2026 15:07:41 -0600 Subject: [PATCH 01/27] add disco to core --- .../src/OrdinaryDiffEqCore.jl | 1 + lib/OrdinaryDiffEqCore/src/disco.jl | 77 ++++++ .../src/integrators/controllers.jl | 6 + .../src/integrators/type.jl | 1 + lib/OrdinaryDiffEqCore/src/solve.jl | 3 +- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 256 ++++++++++++++++++ 6 files changed, 343 insertions(+), 1 deletion(-) create mode 100644 lib/OrdinaryDiffEqCore/src/disco.jl create mode 100644 lib/OrdinaryDiffEqCore/test/disco_tests.jl diff --git a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl index fbfd8a743b7..29acc5f53c2 100644 --- a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl +++ b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl @@ -162,6 +162,7 @@ include("cache_utils.jl") include("initialize_dae.jl") include("perform_step/composite_perform_step.jl") +include("disco.jl") include("dense/generic_dense.jl") diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl new file mode 100644 index 00000000000..80bdfe143c4 --- /dev/null +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -0,0 +1,77 @@ +function set_discontinuity(u, uprev, integrator, cache) #need to pick algs to test + breakpointθ = find_discontinuity(u, uprev, integrator, cache) + dt = integrator.dt + t = integrator.t + if !isnan(breakpointθ) && 1e-6 < breakpointθ < 1.0 + #println("Discontinuity detected at t = ", t + breakpointθ * dt) + integrator.dt = breakpointθ * dt + integrator.disco_dt_set = true + end +end + +function find_discontinuity(u, uprev, integrator, cache) + cb = integrator.opts.callback + cb === nothing && return -1 + isempty(cb.continuous_callbacks) && return -1 + + disco_exists = false; + for i in cb.continuous_callbacks + if (i.is_discontinuity) + disco_exists = true + break + end + end + !disco_exists && return -1 + p = integrator.p + t = integrator.t + dt = integrator.dt + breakpointθ = -one(dt) + prob = nothing + for i in cb.continuous_callbacks + if (!(i.is_discontinuity)) + continue + end + out_prev = nothing + out_curr = nothing + is_inplace = DiffEqBase.isinplace(i.condition, 4) + if is_inplace + out_prev = similar(u) + i.condition(out_prev, uprev, t, integrator) + out_curr = similar(u) + i.condition(out_curr, u, t + dt, integrator) + is_inplace = true + else + out_prev = i.condition(uprev, t, integrator) + out_curr = i.condition(u, t + dt, integrator) + is_inplace = false + end + for (idx, (f0, f1)) in enumerate(zip(out_prev, out_curr)) + if (f0 * f1 < zero(f0)) + function zero_func(θ, p) + u₁ = similar(u) + _ode_interpolant!(u₁, θ, dt, uprev, u, integrator.k, cache, + nothing, Val{0}, nothing) + + if is_inplace + out = similar(u) + i.condition(out, u₁, t + θ * dt, integrator) + else + out = i.condition(u₁, t + θ * dt, integrator) + end + out[idx] + end + if prob === nothing + prob = IntervalNonlinearProblem(zero_func, [zero(dt), one(dt)], p) + else + prob = remake(prob; f=zero_func) + end + sol = solve(prob; bracket=[zero(dt), one(dt)]) + tmp = sol[] + if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) + breakpointθ = tmp + end + end + end + end + breakpointθ +end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index 5ad0824b66c..6325a9bb17c 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -948,6 +948,12 @@ end function step_reject_controller!(integrator, cache::PredictiveControllerCache, alg) (; dt, success_iter) = integrator (; qold) = cache + + if (integrator.disco_dt_set) + println("using fixed dt from discontinuity handling") + integrator.disco_dt_set = false + return integrator.dt + end return integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/type.jl b/lib/OrdinaryDiffEqCore/src/integrators/type.jl index d97636952bd..6234de5143b 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/type.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/type.jl @@ -108,6 +108,7 @@ mutable struct ODEIntegrator{ dtcache::tType dtchangeable::Bool dtpropose::tType + disco_dt_set::Bool tdir::tdirType eigen_est::eigenType EEst::EEstT diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index 976402705f8..b9ce19f7622 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -32,6 +32,7 @@ function SciMLBase.__init( save_everystep = isempty(saveat), save_on = true, save_discretes = true, + disco_dt_set = false, save_start = save_everystep || isempty(saveat) || saveat isa Number || prob.tspan[1] in saveat, save_end = nothing, @@ -667,7 +668,7 @@ function SciMLBase.__init( sol, u, du, k, t, tType(_dt), f, p, uprev, uprev2, duprev, tprev, _alg, dtcache, dtchangeable, - dtpropose, tdir, eigen_est, EEst, + dtpropose, disco_dt_set, tdir, eigen_est, EEst, # TODO vvv remove these QT(qoldinit), q11, erracc, dtacc, diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl new file mode 100644 index 00000000000..1a26db54328 --- /dev/null +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -0,0 +1,256 @@ +using OrdinaryDiffEqCore +using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra +using OrdinaryDiffEqRosenbrock +using OrdinaryDiffEqBDF + +#test example discontinuous at u = 1 +f(u, p, t) = u[1] < 1 ? [2u[1]] : [-3u[1] + 5] +u0 = [0.1] +tspan = (0.0, 1.5) +prob = ODEProblem(f, u0, tspan) + +#define callback +condition(u, t, integrator) = u[1] - 1 +function affect!(integrator) + integrator.u[1] += 10 + println("Callback fired at t = ", integrator.t) +end +cb = ContinuousCallback(condition, affect!; is_discontinuity = true) + +#disco solve +sol_disco = solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) +#fixed order solve +sol_no_disco = solve(prob, RadauIIA5(is_disco = false); callback = cb, reltol = 1e-6) + +rodas_no_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) + +rodas_disco = solve(prob, Rodas5P(is_disco = true); callback = cb, reltol = 1e-6) + +bdf_no_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) + +bdf_disco = solve(prob, FBDF(is_disco = true); callback = cb, reltol = 1e-6) + +#two discontinuity functions +function f(u, p, t) + if u[1] < 1 + [2u[1]] # region 1: grows to hit u = 1 + elseif u[1] < 2 + [u[1] + 0.2] # region 2: continues increasing to hit u = 2 + else + [-4u[1] + 12] # region 3: after 2, moves toward u ≈ 3 + end +end + +u0 = [0.1] +tspan = (0.0, 2.5) +prob = ODEProblem(f, u0, tspan) + +#define callbacks +condition1(u, t, integrator) = u[1] - 1 +function affect1!(integrator) + #println("Callback 1 fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb1 = ContinuousCallback(condition1, affect1!; is_discontinuity = true) + +condition2(u, t, integrator) = u[1] - 2 +function affect2!(integrator) + #println("Callback 2 fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb2 = ContinuousCallback(condition2, affect2!; is_discontinuity = true) +cb = CallbackSet(cb1, cb2) + +#disco solve +sol_disco = solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) +#fixed order solve +sol_no_disco = solve(prob, RadauIIA5(is_disco = false); callback = cb, reltol = 1e-6) + +rodas_no_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) + +rodas_disco = solve(prob, Rodas5P(is_disco = true); callback = cb, reltol = 1e-6) + +bdf_no_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) + +bdf_disco = solve(prob, FBDF(is_disco = true); callback = cb, reltol = 1e-6) + + +# multiple exponential regions with sharp transitions +function f_multi_exp!(du, u, p, t) + if u[1] < 0.3 + du[1] = 3 * exp(3 * u[1]) # very steep exponential + elseif u[1] < 0.8 + du[1] = exp(u[1]) # slower exponential + else + du[1] = u[1] # linear + end +end + +u0_multi = [0.05] +tspan_multi = (0.0, 1.5) +prob_multi = ODEProblem(f_multi_exp!, u0_multi, tspan_multi) + +#define callbacks +cond_multi_1(u, t, integrator) = u[1] - 0.3 +function affect_multi_1!(integrator) + println("Multi-exponential discontinuity 1 callback fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb_multi_1 = ContinuousCallback(cond_multi_1, affect_multi_1!; is_discontinuity = true) + +cond_multi_2(u, t, integrator) = u[1] - 0.8 +function affect_multi_2!(integrator) + println("Multi-exponential discontinuity 2 callback fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb_multi_2 = ContinuousCallback(cond_multi_2, affect_multi_2!; is_discontinuity = true) +cb_multi = CallbackSet(cb_multi_1, cb_multi_2) + +#disco solve +sol_disco = solve(prob_multi, RadauIIA5(is_disco = true); callback=cb_multi, reltol=1e-7, abstol=1e-9) +#fixed order solve +sol_no_disco = solve(prob_multi, RadauIIA5(is_disco = false); callback=cb_multi, reltol = 1e-7, abstol = 1e-9) + +# 2D system with exponential coupling and discontinuity +function f_2d_exp!(du, u, p, t) + if u[1] + u[2] < 1.0 + du[1] = 2 * exp(u[1]) - u[2] + du[2] = -3 * u[1] + 4 * exp(u[2]) + else + du[1] = u[1] + du[2] = u[2] + end +end + +u0_2d = [0.1, 0.2] +tspan_2d = (0.0, 2.0) +prob_2d = ODEProblem(f_2d_exp!, u0_2d, tspan_2d) + +#define callback +cond_2d(u, t, integrator) = u[1] + u[2] - 1.0 +function affect_2d!(integrator) + println("2D exponential discontinuity callback fired at t=$(integrator.t), u=$(integrator.u)") + @test 0.98 < integrator.u[1] + integrator.u[2] < 1.02 +end +cb_2d = ContinuousCallback(cond_2d, affect_2d!; is_discontinuity = true) + +#disco solve +sol_disco = solve(prob_2d, RadauIIA5(is_disco = true); callback=cb_2d, reltol=1e-8, abstol=1e-10) +#fixed order solve +sol_no_disco = solve(prob_2d, RadauIIA5(is_disco = false); callback=cb_2d, reltol = 1e-8, abstol = 1e-10) + +# very stiff discontinuous system +function f_stiff_disc!(du, u, p, t) + λ = p[1] # stiffness parameter + if u[1] < 0.5 + du[1] = -λ * u[1] + λ * exp(-t) # stiff decay with forcing + else + du[1] = u[1] + end +end + +u0_stiff = [0.1] +tspan_stiff = (0.0, 3.0) +prob_stiff = ODEProblem(f_stiff_disc!, u0_stiff, tspan_stiff, [100.0]) + +#define callback +cond_stiff(u, t, integrator) = u[1] - 0.5 +function affect_stiff!(integrator) + println("Stiff discontinuity callback fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb_stiff = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = true) + +#disco solve +sol_disco = solve(prob_stiff, RadauIIA5(is_disco = true); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +#fixed order solve +sol_no_disco = solve(prob_stiff, RadauIIA5(is_disco = false); callback=cb_stiff, reltol = 1e-9, abstol = 1e-11) + +# multiple discontinuities in very small range (1e-6 apart, 5 discontinuities) +function f_many_disc!(du, u, p, t) + du[1] = u[1] + 1 # simple linear growth +end + +u0_many = [0.0] +tspan_many = (0.0, 1.0) +prob_many = ODEProblem(f_many_disc!, u0_many, tspan_many) + +# create 5 discontinuities spaced 1e-6 apart +disc_values = [0.1 + i * 1e-6 for i = 0:4] + +# define callbacks for each discontinuity +cbs_many = [] +for (i, disc_val) in enumerate(disc_values) + local cond_func(u, t, integrator) = u[1] - disc_val + function affect_func!(integrator) + println("Dense discontinuity $i fired at t=$(integrator.t), u=$(integrator.u[1])") + end + push!(cbs_many, ContinuousCallback(cond_func, affect_func!; is_discontinuity = true)) +end +cb_many = CallbackSet(cbs_many...) + +#disco solve +sol_disco = solve(prob_many, RadauIIA5(is_disco = true); callback=cb_many, reltol=1e-10, abstol=1e-12) +#fixed order solve +sol_no_disco = solve(prob_many, RadauIIA5(is_disco = false); callback=cb_many, reltol=1e-10, abstol=1e-12) + +# discontinuous DAE with mass matrix +# System: M * du/dt = f(u, p, t) +# du[1]/dt = u[2] - u[1] +# 0 = u[1] + u[2] - 1 (algebraic constraint) +function f_dae_disc!(du, u, p, t) + if u[1] < 0.5 + du[1] = 2 * u[2] - u[1] + du[2] = u[1] + u[2] - 1 # algebraic constraint + else + du[1] = -u[1] + u[2] + du[2] = u[1] + u[2] - 1 # algebraic constraint + end +end + +u0_dae = [0.2, 0.8] # consistent with constraint u[1] + u[2] = 1 +tspan_dae = (0.0, 2.0) + +M_dae = [1.0 0.0; 0.0 0.0] + +f_dae_func = ODEFunction(f_dae_disc!; mass_matrix=M_dae) +prob_dae = ODEProblem(f_dae_func, u0_dae, tspan_dae) + +cond_dae(u, t, integrator) = u[1] - 0.5 +function affect_dae!(integrator) + #println("DAE discontinuity callback fired at t=$(integrator.t), u=$(integrator.u)") +end +cb_dae = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = true) + + +radau_no_disco = solve(prob_dae, RadauIIA5(is_disco = false); callback=cb_dae, reltol=1e-8, abstol=1e-10) + #83.500 μs (769 allocations: 35.72 KiB) +radau_disco = solve(prob_dae, RadauIIA5(is_disco = true); callback=cb_dae, reltol=1e-8, abstol=1e-10) + # 119.417 μs (1273 allocations: 55.42 KiB) +rodas_no_disco = solve(prob_dae, Rodas5P(); callback = cb_dae, reltol = 1e-6) +#= SciMLBase.DEStats +Number of function 1 evaluations: 312 +Number of function 2 evaluations: 0 +Number of W matrix evaluations: 34 +Number of linear solves: 272 +Number of Jacobians created: 19 +Number of nonlinear solver iterations: 0 +Number of nonlinear solver convergence failures: 0 +Number of fixed-point solver iterations: 0 +Number of fixed-point solver convergence failures: 0 +Number of rootfind condition calls: 213 +Number of accepted steps: 19 +Number of rejected steps: 15 =# +# 98.167 μs (550 allocations: 26.92 KiB) +rodas_disco = solve(prob_dae, Rodas5P(is_disco = true); callback = cb_dae, reltol = 1e-6) +#= SciMLBase.DEStats +Number of function 1 evaluations: 312 +Number of function 2 evaluations: 0 +Number of W matrix evaluations: 34 +Number of linear solves: 272 +Number of Jacobians created: 19 +Number of nonlinear solver iterations: 0 +Number of nonlinear solver convergence failures: 0 +Number of fixed-point solver iterations: 0 +Number of fixed-point solver convergence failures: 0 +Number of rootfind condition calls: 213 +Number of accepted steps: 19 +Number of rejected steps: 15 =# +# 97.541 μs (550 allocations: 26.92 KiB) +bdf_no_disco = solve(prob_dae, FBDF(); callback = cb_dae, reltol = 1e-6) +bdf_disco = solve(prob_dae, FBDF(is_disco = true); callback = cb_dae, reltol = 1e-6) \ No newline at end of file From f485bdc59b13261a5eb7e958e13be9f6eea3b8bf Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 23 Feb 2026 15:17:16 -0600 Subject: [PATCH 02/27] radau version --- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 5 ++--- lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl | 2 +- lib/OrdinaryDiffEqFIRK/src/algorithms.jl | 6 ++++-- lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl | 8 ++++++++ 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 1a26db54328..3b27868de8c 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,4 +1,3 @@ -using OrdinaryDiffEqCore using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra using OrdinaryDiffEqRosenbrock using OrdinaryDiffEqBDF @@ -13,15 +12,15 @@ prob = ODEProblem(f, u0, tspan) condition(u, t, integrator) = u[1] - 1 function affect!(integrator) integrator.u[1] += 10 - println("Callback fired at t = ", integrator.t) end cb = ContinuousCallback(condition, affect!; is_discontinuity = true) #disco solve sol_disco = solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) +# 291.292 μs (8449 allocations: 266.47 KiB) #fixed order solve sol_no_disco = solve(prob, RadauIIA5(is_disco = false); callback = cb, reltol = 1e-6) - +# 335.417 μs (10008 allocations: 311.08 KiB) rodas_no_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) rodas_disco = solve(prob, Rodas5P(is_disco = true); callback = cb, reltol = 1e-6) diff --git a/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl b/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl index a0cd04b69bf..d77dc456c45 100644 --- a/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl +++ b/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl @@ -18,7 +18,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, fac_default_gamma, get_current_adaptive_order, get_fsalfirstlast, isfirk, generic_solver_docstring, _bool_to_ADType, - _process_AD_choice, LinearAliasSpecifier + _process_AD_choice, LinearAliasSpecifier, set_discontinuity using MuladdMacro, DiffEqBase, RecursiveArrayTools, Polyester isfirk, generic_solver_docstring using SciMLOperators: AbstractSciMLOperator diff --git a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl index aa433958379..3ee4ecaa422 100644 --- a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl +++ b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl @@ -92,6 +92,7 @@ struct RadauIIA5{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <: controller::Symbol step_limiter!::StepLimiter autodiff::AD + is_disco::Bool end function RadauIIA5(; @@ -102,7 +103,7 @@ function RadauIIA5(; extrapolant = :dense, fast_convergence_cutoff = 1 // 5, new_W_γdt_cutoff = 1 // 5, controller = :Predictive, κ = nothing, maxiters = 10, smooth_est = true, - step_limiter! = trivial_limiter! + step_limiter! = trivial_limiter!, is_disco = false ) AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) @@ -122,7 +123,8 @@ function RadauIIA5(; new_W_γdt_cutoff, controller, step_limiter!, - AD_choice + AD_choice, + is_disco ) end diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl index 499b345d13c..ce535649b5a 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl @@ -675,6 +675,10 @@ end integrator.k[4] = z2 integrator.k[5] = z3 end + else + if alg.is_disco + set_discontinuity(u, uprev, integrator, cache) + end end integrator.fsallast = f(u, p, t + dt) @@ -952,6 +956,10 @@ end integrator.k[4] .= z2 integrator.k[5] .= z3 end + else + if alg.is_disco + set_discontinuity(u, uprev, integrator, cache) + end end f(fsallast, u, p, t + dt) From cbd2b1c789a9391b1ffab6760ecbf6e9221eb1d2 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 23 Feb 2026 15:50:14 -0600 Subject: [PATCH 03/27] bdf and rodas --- lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl | 2 +- lib/OrdinaryDiffEqBDF/src/algorithms.jl | 7 ++++--- lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl | 7 +++++++ lib/OrdinaryDiffEqCore/test/disco_tests.jl | 12 +++++++----- .../src/OrdinaryDiffEqRosenbrock.jl | 2 +- lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl | 7 ++++--- .../src/rosenbrock_perform_step.jl | 7 +++++++ 7 files changed, 31 insertions(+), 13 deletions(-) diff --git a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl index f6e404018da..50c5eb33fb8 100644 --- a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl +++ b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl @@ -23,7 +23,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, get_fsalfirstlast, generic_solver_docstring, _bool_to_ADType, _process_AD_choice, _ode_interpolant, _ode_interpolant!, has_stiff_interpolation, - _ode_addsteps!, DerivativeOrderNotPossibleError + _ode_addsteps!, DerivativeOrderNotPossibleError, set_discontinuity using OrdinaryDiffEqSDIRK: ImplicitEulerConstantCache, ImplicitEulerCache using TruncatedStacktraces: @truncate_stacktrace diff --git a/lib/OrdinaryDiffEqBDF/src/algorithms.jl b/lib/OrdinaryDiffEqBDF/src/algorithms.jl index ad0aba5f193..6aa128a6592 100644 --- a/lib/OrdinaryDiffEqBDF/src/algorithms.jl +++ b/lib/OrdinaryDiffEqBDF/src/algorithms.jl @@ -575,6 +575,7 @@ struct FBDF{MO, CS, AD, F, F2, P, FDT, ST, CJ, K, T, StepLimiter} <: controller::Symbol step_limiter!::StepLimiter autodiff::AD + is_disco::Bool end function FBDF(; @@ -583,7 +584,7 @@ function FBDF(; diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing, tol = nothing, - extrapolant = :linear, controller = :Standard, step_limiter! = trivial_limiter! + extrapolant = :linear, controller = :Standard, step_limiter! = trivial_limiter!, is_disco = false ) where {MO} AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) @@ -594,7 +595,7 @@ function FBDF(; typeof(κ), typeof(tol), typeof(step_limiter!), }( max_order, linsolve, nlsolve, precs, κ, tol, extrapolant, - controller, step_limiter!, AD_choice + controller, step_limiter!, AD_choice, is_disco ) end @@ -841,4 +842,4 @@ function DFBDF(; ) end -@truncate_stacktrace DFBDF +@truncate_stacktrace DFBDF \ No newline at end of file diff --git a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl index 59deace1b22..7e9cd3f3031 100644 --- a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl +++ b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl @@ -1276,6 +1276,9 @@ function perform_step!( integrator.opts.reltol, integrator.opts.internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) + if (integrator.EEst > one(integrator.EEst) && integrator.alg.is_disco) + set_discontinuity(u, uprev, integrator, cache) + end terk = estimate_terk(integrator, cache, k + 1, Val(max_order), u) fd_weights = calc_finite_difference_weights(ts_tmp, tdt, k, Val(max_order)) @@ -1483,6 +1486,10 @@ function perform_step!( internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) + if (integrator.EEst > one(integrator.EEst) && integrator.alg.is_disco) + set_discontinuity(u, uprev, integrator, cache) + end + estimate_terk!(integrator, cache, k + 1, Val(max_order)) calculate_residuals!( atmp, _vec(terk_tmp), _vec(uprev), _vec(u), abstol, reltol, diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 3b27868de8c..0f69ed6d3db 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -15,20 +15,22 @@ function affect!(integrator) end cb = ContinuousCallback(condition, affect!; is_discontinuity = true) -#disco solve sol_disco = solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) # 291.292 μs (8449 allocations: 266.47 KiB) -#fixed order solve sol_no_disco = solve(prob, RadauIIA5(is_disco = false); callback = cb, reltol = 1e-6) # 335.417 μs (10008 allocations: 311.08 KiB) rodas_no_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) - +# 410.291 μs (16828 allocations: 594.05 KiB) rodas_disco = solve(prob, Rodas5P(is_disco = true); callback = cb, reltol = 1e-6) - +# 483.792 μs (17729 allocations: 639.31 KiB) bdf_no_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) - +# 245.917 μs (20703 allocations: 665.16 KiB) bdf_disco = solve(prob, FBDF(is_disco = true); callback = cb, reltol = 1e-6) +# 269.333 μs (20477 allocations: 663.80 KiB) +@profview for i in 1:1000 + solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) +end #two discontinuity functions function f(u, p, t) if u[1] < 1 diff --git a/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl b/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl index 9ef1b32e2a7..7017e379fad 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl @@ -13,7 +13,7 @@ import OrdinaryDiffEqCore: alg_order, alg_adaptive_order, isWmethod, isfsal, _un calculate_residuals, has_stiff_interpolation, ODEIntegrator, resize_non_user_cache!, _ode_addsteps!, full_cache, DerivativeOrderNotPossibleError, _bool_to_ADType, - _process_AD_choice, LinearAliasSpecifier, copyat_or_push! + _process_AD_choice, LinearAliasSpecifier, copyat_or_push!, set_discontinuity using MuladdMacro, FastBroadcast, RecursiveArrayTools import MacroTools: namify using MacroTools: @capture diff --git a/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl b/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl index cc5a11dedd3..c08af7f8f8c 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl @@ -26,13 +26,14 @@ for (Alg, desc, refs, is_W) in [ step_limiter!::StepLimiter stage_limiter!::StageLimiter autodiff::AD + is_disco::Bool end function $Alg(; chunk_size = Val{0}(), autodiff = AutoForwardDiff(), standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, step_limiter! = trivial_limiter!, - stage_limiter! = trivial_limiter! + stage_limiter! = trivial_limiter!, is_disco = false ) AD_choice, chunk_size, diff_type = _process_AD_choice( autodiff, chunk_size, diff_type @@ -41,10 +42,10 @@ for (Alg, desc, refs, is_W) in [ _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac), typeof(step_limiter!), - typeof(stage_limiter!), + typeof(stage_limiter!) }( linsolve, precs, step_limiter!, - stage_limiter!, AD_choice + stage_limiter!, AD_choice, is_disco ) end end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index 0ad836ac446..038c06fa01c 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -1384,6 +1384,10 @@ end integrator.opts.reltol, integrator.opts.internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) + + if integrator.EEst > one(eltype(integrator.EEst)) && integrator.alg.is_disco + set_discontinuity(u, uprev, integrator, cache) + end end if integrator.opts.calck @@ -1524,6 +1528,9 @@ end integrator.opts.reltol, integrator.opts.internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) + if integrator.EEst > one(eltype(integrator.EEst)) && integrator.alg.is_disco + set_discontinuity(u, uprev, integrator, cache) + end end if integrator.opts.calck From aff9da0d0c899f67cbc4dd7bdaf86df4fd28c853 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Sun, 1 Mar 2026 19:38:57 -0600 Subject: [PATCH 04/27] refactor disco into controllers --- .../src/OrdinaryDiffEqBDF.jl | 2 +- lib/OrdinaryDiffEqBDF/src/algorithms.jl | 5 +-- lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl | 6 --- lib/OrdinaryDiffEqCore/src/disco.jl | 18 ++------- .../src/integrators/controllers.jl | 37 ++++++++++++++++++- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 29 ++++++++++----- .../src/OrdinaryDiffEqFIRK.jl | 2 +- lib/OrdinaryDiffEqFIRK/src/algorithms.jl | 4 +- .../src/firk_perform_step.jl | 8 ---- .../src/OrdinaryDiffEqRosenbrock.jl | 2 +- .../src/algorithms.jl | 5 +-- .../src/rosenbrock_perform_step.jl | 7 ---- 12 files changed, 68 insertions(+), 57 deletions(-) diff --git a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl index 50c5eb33fb8..f6e404018da 100644 --- a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl +++ b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl @@ -23,7 +23,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, get_fsalfirstlast, generic_solver_docstring, _bool_to_ADType, _process_AD_choice, _ode_interpolant, _ode_interpolant!, has_stiff_interpolation, - _ode_addsteps!, DerivativeOrderNotPossibleError, set_discontinuity + _ode_addsteps!, DerivativeOrderNotPossibleError using OrdinaryDiffEqSDIRK: ImplicitEulerConstantCache, ImplicitEulerCache using TruncatedStacktraces: @truncate_stacktrace diff --git a/lib/OrdinaryDiffEqBDF/src/algorithms.jl b/lib/OrdinaryDiffEqBDF/src/algorithms.jl index 6aa128a6592..d6433237802 100644 --- a/lib/OrdinaryDiffEqBDF/src/algorithms.jl +++ b/lib/OrdinaryDiffEqBDF/src/algorithms.jl @@ -575,7 +575,6 @@ struct FBDF{MO, CS, AD, F, F2, P, FDT, ST, CJ, K, T, StepLimiter} <: controller::Symbol step_limiter!::StepLimiter autodiff::AD - is_disco::Bool end function FBDF(; @@ -584,7 +583,7 @@ function FBDF(; diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing, tol = nothing, - extrapolant = :linear, controller = :Standard, step_limiter! = trivial_limiter!, is_disco = false + extrapolant = :linear, controller = :Standard, step_limiter! = trivial_limiter! ) where {MO} AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) @@ -595,7 +594,7 @@ function FBDF(; typeof(κ), typeof(tol), typeof(step_limiter!), }( max_order, linsolve, nlsolve, precs, κ, tol, extrapolant, - controller, step_limiter!, AD_choice, is_disco + controller, step_limiter!, AD_choice ) end diff --git a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl index 7e9cd3f3031..bee908fe82d 100644 --- a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl +++ b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl @@ -1276,9 +1276,6 @@ function perform_step!( integrator.opts.reltol, integrator.opts.internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) - if (integrator.EEst > one(integrator.EEst) && integrator.alg.is_disco) - set_discontinuity(u, uprev, integrator, cache) - end terk = estimate_terk(integrator, cache, k + 1, Val(max_order), u) fd_weights = calc_finite_difference_weights(ts_tmp, tdt, k, Val(max_order)) @@ -1486,9 +1483,6 @@ function perform_step!( internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) - if (integrator.EEst > one(integrator.EEst) && integrator.alg.is_disco) - set_discontinuity(u, uprev, integrator, cache) - end estimate_terk!(integrator, cache, k + 1, Val(max_order)) calculate_residuals!( diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index 80bdfe143c4..20698ade0f1 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -4,24 +4,16 @@ function set_discontinuity(u, uprev, integrator, cache) #need to pick algs to te t = integrator.t if !isnan(breakpointθ) && 1e-6 < breakpointθ < 1.0 #println("Discontinuity detected at t = ", t + breakpointθ * dt) - integrator.dt = breakpointθ * dt - integrator.disco_dt_set = true + return breakpointθ * dt end + return -1 end function find_discontinuity(u, uprev, integrator, cache) + println("Finding discontinuity...") cb = integrator.opts.callback cb === nothing && return -1 isempty(cb.continuous_callbacks) && return -1 - - disco_exists = false; - for i in cb.continuous_callbacks - if (i.is_discontinuity) - disco_exists = true - break - end - end - !disco_exists && return -1 p = integrator.p t = integrator.t dt = integrator.dt @@ -49,9 +41,7 @@ function find_discontinuity(u, uprev, integrator, cache) if (f0 * f1 < zero(f0)) function zero_func(θ, p) u₁ = similar(u) - _ode_interpolant!(u₁, θ, dt, uprev, u, integrator.k, cache, - nothing, Val{0}, nothing) - + ode_interpolant!(u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) if is_inplace out = similar(u) i.condition(out, u₁, t + θ * dt, integrator) diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index 6325a9bb17c..a9f0cb211c8 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -169,6 +169,11 @@ end function step_reject_controller!(integrator, controller::IController, alg) (; qold) = integrator + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt = qold end @@ -241,6 +246,11 @@ end function step_reject_controller!(integrator, cache::IControllerCache, alg) @assert cache.dtreject ≈ integrator.qold "Controller cache went out of sync with time stepping logic." + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt = cache.dtreject # TODO this does not look right. end @@ -320,6 +330,11 @@ end function step_reject_controller!(integrator, controller::PIController, alg) (; q11) = integrator (; qmin, gamma) = integrator.opts + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt /= min(inv(qmin), q11 / gamma) end @@ -423,6 +438,11 @@ end function step_reject_controller!(integrator, cache::PIControllerCache, alg) (; controller, q11) = cache (; qmin, gamma) = controller + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt /= min(inv(qmin), q11 / gamma) end @@ -599,6 +619,11 @@ function step_accept_controller!(integrator, controller::PIDController, alg, dt_ end function step_reject_controller!(integrator, controller::PIDController, alg) + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt *= integrator.qold end @@ -730,6 +755,11 @@ function step_accept_controller!(integrator, cache::PIDControllerCache, alg, dt_ end function step_reject_controller!(integrator, cache::PIDControllerCache, alg) + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt *= cache.dt_factor end @@ -841,6 +871,12 @@ end function step_reject_controller!(integrator, controller::PredictiveController, alg) (; dt, success_iter, qold) = integrator + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end + return integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold end @@ -948,7 +984,6 @@ end function step_reject_controller!(integrator, cache::PredictiveControllerCache, alg) (; dt, success_iter) = integrator (; qold) = cache - if (integrator.disco_dt_set) println("using fixed dt from discontinuity handling") integrator.disco_dt_set = false diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 0f69ed6d3db..5f3b80afa7c 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,6 +1,5 @@ using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra -using OrdinaryDiffEqRosenbrock -using OrdinaryDiffEqBDF +using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner #test example discontinuous at u = 1 f(u, p, t) = u[1] < 1 ? [2u[1]] : [-3u[1] + 5] @@ -14,23 +13,35 @@ function affect!(integrator) integrator.u[1] += 10 end cb = ContinuousCallback(condition, affect!; is_discontinuity = true) +cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) -sol_disco = solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) +sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) # 291.292 μs (8449 allocations: 266.47 KiB) -sol_no_disco = solve(prob, RadauIIA5(is_disco = false); callback = cb, reltol = 1e-6) +sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 335.417 μs (10008 allocations: 311.08 KiB) -rodas_no_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) + +rodas_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) # 410.291 μs (16828 allocations: 594.05 KiB) -rodas_disco = solve(prob, Rodas5P(is_disco = true); callback = cb, reltol = 1e-6) +rodas_no_disco = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) # 483.792 μs (17729 allocations: 639.31 KiB) -bdf_no_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) + +bdf_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) # 245.917 μs (20703 allocations: 665.16 KiB) -bdf_disco = solve(prob, FBDF(is_disco = true); callback = cb, reltol = 1e-6) +bdf_no_disco = solve(prob, FBDF(); callback = cb2, reltol = 1e-6) # 269.333 μs (20477 allocations: 663.80 KiB) +tsit_disco = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) +# same either way for some reason? check about this +tsit_no_disco = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) + +vern_disco = solve(prob, Vern7(); callback = cb, reltol = 1e-6) +# 111.125 μs (15629 allocations: 493.66 KiB) +vern_no_disco = solve(prob, Vern7(); callback = cb2, reltol = 1e-6) +# 83.666 μs (13326 allocations: 420.31 KiB) @profview for i in 1:1000 - solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) + solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) end + #two discontinuity functions function f(u, p, t) if u[1] < 1 diff --git a/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl b/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl index d77dc456c45..a0cd04b69bf 100644 --- a/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl +++ b/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl @@ -18,7 +18,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, fac_default_gamma, get_current_adaptive_order, get_fsalfirstlast, isfirk, generic_solver_docstring, _bool_to_ADType, - _process_AD_choice, LinearAliasSpecifier, set_discontinuity + _process_AD_choice, LinearAliasSpecifier using MuladdMacro, DiffEqBase, RecursiveArrayTools, Polyester isfirk, generic_solver_docstring using SciMLOperators: AbstractSciMLOperator diff --git a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl index 3ee4ecaa422..2231de529dd 100644 --- a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl +++ b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl @@ -92,7 +92,6 @@ struct RadauIIA5{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <: controller::Symbol step_limiter!::StepLimiter autodiff::AD - is_disco::Bool end function RadauIIA5(; @@ -103,7 +102,7 @@ function RadauIIA5(; extrapolant = :dense, fast_convergence_cutoff = 1 // 5, new_W_γdt_cutoff = 1 // 5, controller = :Predictive, κ = nothing, maxiters = 10, smooth_est = true, - step_limiter! = trivial_limiter!, is_disco = false + step_limiter! = trivial_limiter! ) AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) @@ -124,7 +123,6 @@ function RadauIIA5(; controller, step_limiter!, AD_choice, - is_disco ) end diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl index ce535649b5a..499b345d13c 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl @@ -675,10 +675,6 @@ end integrator.k[4] = z2 integrator.k[5] = z3 end - else - if alg.is_disco - set_discontinuity(u, uprev, integrator, cache) - end end integrator.fsallast = f(u, p, t + dt) @@ -956,10 +952,6 @@ end integrator.k[4] .= z2 integrator.k[5] .= z3 end - else - if alg.is_disco - set_discontinuity(u, uprev, integrator, cache) - end end f(fsallast, u, p, t + dt) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl b/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl index 7017e379fad..9ef1b32e2a7 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl @@ -13,7 +13,7 @@ import OrdinaryDiffEqCore: alg_order, alg_adaptive_order, isWmethod, isfsal, _un calculate_residuals, has_stiff_interpolation, ODEIntegrator, resize_non_user_cache!, _ode_addsteps!, full_cache, DerivativeOrderNotPossibleError, _bool_to_ADType, - _process_AD_choice, LinearAliasSpecifier, copyat_or_push!, set_discontinuity + _process_AD_choice, LinearAliasSpecifier, copyat_or_push! using MuladdMacro, FastBroadcast, RecursiveArrayTools import MacroTools: namify using MacroTools: @capture diff --git a/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl b/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl index c08af7f8f8c..3e2374e59ca 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl @@ -26,14 +26,13 @@ for (Alg, desc, refs, is_W) in [ step_limiter!::StepLimiter stage_limiter!::StageLimiter autodiff::AD - is_disco::Bool end function $Alg(; chunk_size = Val{0}(), autodiff = AutoForwardDiff(), standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, step_limiter! = trivial_limiter!, - stage_limiter! = trivial_limiter!, is_disco = false + stage_limiter! = trivial_limiter! ) AD_choice, chunk_size, diff_type = _process_AD_choice( autodiff, chunk_size, diff_type @@ -45,7 +44,7 @@ for (Alg, desc, refs, is_W) in [ typeof(stage_limiter!) }( linsolve, precs, step_limiter!, - stage_limiter!, AD_choice, is_disco + stage_limiter!, AD_choice ) end end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index 038c06fa01c..0ad836ac446 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -1384,10 +1384,6 @@ end integrator.opts.reltol, integrator.opts.internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) - - if integrator.EEst > one(eltype(integrator.EEst)) && integrator.alg.is_disco - set_discontinuity(u, uprev, integrator, cache) - end end if integrator.opts.calck @@ -1528,9 +1524,6 @@ end integrator.opts.reltol, integrator.opts.internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) - if integrator.EEst > one(eltype(integrator.EEst)) && integrator.alg.is_disco - set_discontinuity(u, uprev, integrator, cache) - end end if integrator.opts.calck From 9d306b7cf28f2d70c6fa187ccb4723b11083f03a Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Tue, 3 Mar 2026 11:38:14 -0600 Subject: [PATCH 05/27] add disco to BDF controller --- lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl | 2 +- lib/OrdinaryDiffEqBDF/src/controllers.jl | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl index f6e404018da..50c5eb33fb8 100644 --- a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl +++ b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl @@ -23,7 +23,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, get_fsalfirstlast, generic_solver_docstring, _bool_to_ADType, _process_AD_choice, _ode_interpolant, _ode_interpolant!, has_stiff_interpolation, - _ode_addsteps!, DerivativeOrderNotPossibleError + _ode_addsteps!, DerivativeOrderNotPossibleError, set_discontinuity using OrdinaryDiffEqSDIRK: ImplicitEulerConstantCache, ImplicitEulerCache using TruncatedStacktraces: @truncate_stacktrace diff --git a/lib/OrdinaryDiffEqBDF/src/controllers.jl b/lib/OrdinaryDiffEqBDF/src/controllers.jl index 3be5968b186..a14a1672d38 100644 --- a/lib/OrdinaryDiffEqBDF/src/controllers.jl +++ b/lib/OrdinaryDiffEqBDF/src/controllers.jl @@ -112,6 +112,13 @@ function bdf_step_reject_controller!(integrator, EEst1) h = integrator.dt integrator.cache.consfailcnt += 1 integrator.cache.nconsteps = 0 + + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end + if integrator.cache.consfailcnt > 1 h = h / 2 end From 2158f3b9bb2ce9aadb623f53a92ee69a95c2dd21 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Tue, 3 Mar 2026 11:54:04 -0600 Subject: [PATCH 06/27] fix small edits --- lib/OrdinaryDiffEqBDF/src/algorithms.jl | 2 +- lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl | 1 - lib/OrdinaryDiffEqFIRK/src/algorithms.jl | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/OrdinaryDiffEqBDF/src/algorithms.jl b/lib/OrdinaryDiffEqBDF/src/algorithms.jl index d6433237802..ad0aba5f193 100644 --- a/lib/OrdinaryDiffEqBDF/src/algorithms.jl +++ b/lib/OrdinaryDiffEqBDF/src/algorithms.jl @@ -841,4 +841,4 @@ function DFBDF(; ) end -@truncate_stacktrace DFBDF \ No newline at end of file +@truncate_stacktrace DFBDF diff --git a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl index bee908fe82d..59deace1b22 100644 --- a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl +++ b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl @@ -1483,7 +1483,6 @@ function perform_step!( internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) - estimate_terk!(integrator, cache, k + 1, Val(max_order)) calculate_residuals!( atmp, _vec(terk_tmp), _vec(uprev), _vec(u), abstol, reltol, diff --git a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl index 2231de529dd..aa433958379 100644 --- a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl +++ b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl @@ -122,7 +122,7 @@ function RadauIIA5(; new_W_γdt_cutoff, controller, step_limiter!, - AD_choice, + AD_choice ) end From 2b9a5097a75769e5376f8bc20a37177bdfd23993 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Tue, 3 Mar 2026 11:58:45 -0600 Subject: [PATCH 07/27] Update algorithms.jl --- lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl b/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl index 3e2374e59ca..cc5a11dedd3 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl @@ -41,7 +41,7 @@ for (Alg, desc, refs, is_W) in [ _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac), typeof(step_limiter!), - typeof(stage_limiter!) + typeof(stage_limiter!), }( linsolve, precs, step_limiter!, stage_limiter!, AD_choice From d8c1a3be88e319b318194ee6dce35e772b7a724f Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 16 Mar 2026 16:51:27 -0400 Subject: [PATCH 08/27] update disco scheme by caching problems in integrator --- lib/OrdinaryDiffEqCore/src/disco.jl | 53 +++-- .../src/integrators/type.jl | 1 + lib/OrdinaryDiffEqCore/src/solve.jl | 23 ++- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 186 ++++++++---------- 4 files changed, 130 insertions(+), 133 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index 20698ade0f1..f402034a376 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -1,4 +1,4 @@ -function set_discontinuity(u, uprev, integrator, cache) #need to pick algs to test +function set_discontinuity(u, uprev, integrator, cache) breakpointθ = find_discontinuity(u, uprev, integrator, cache) dt = integrator.dt t = integrator.t @@ -10,7 +10,6 @@ function set_discontinuity(u, uprev, integrator, cache) #need to pick algs to te end function find_discontinuity(u, uprev, integrator, cache) - println("Finding discontinuity...") cb = integrator.opts.callback cb === nothing && return -1 isempty(cb.continuous_callbacks) && return -1 @@ -18,49 +17,45 @@ function find_discontinuity(u, uprev, integrator, cache) t = integrator.t dt = integrator.dt breakpointθ = -one(dt) - prob = nothing + idx = 1 for i in cb.continuous_callbacks if (!(i.is_discontinuity)) continue end - out_prev = nothing - out_curr = nothing - is_inplace = DiffEqBase.isinplace(i.condition, 4) - if is_inplace + if (i isa VectorContinuousCallback) out_prev = similar(u) + out_curr = similar(u) i.condition(out_prev, uprev, t, integrator) - out_curr = similar(u) i.condition(out_curr, u, t + dt, integrator) - is_inplace = true - else - out_prev = i.condition(uprev, t, integrator) - out_curr = i.condition(u, t + dt, integrator) - is_inplace = false - end - for (idx, (f0, f1)) in enumerate(zip(out_prev, out_curr)) - if (f0 * f1 < zero(f0)) - function zero_func(θ, p) + for (ind, (f0, f1)) in enumerate(zip(out_prev, out_curr)) + if (f0 * f1 < zero(f0)) u₁ = similar(u) - ode_interpolant!(u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) - if is_inplace - out = similar(u) - i.condition(out, u₁, t + θ * dt, integrator) - else - out = i.condition(u₁, t + θ * dt, integrator) + out = similar(u) + function zero_func(θ, p) + ode_interpolant!(u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) + i.condition(out, u₁, t + θ * integrator.dt, integrator) + out[ind] end - out[idx] - end - if prob === nothing prob = IntervalNonlinearProblem(zero_func, [zero(dt), one(dt)], p) - else - prob = remake(prob; f=zero_func) + sol = solve(prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) + tmp = sol[] + if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) + breakpointθ = tmp + end end - sol = solve(prob; bracket=[zero(dt), one(dt)]) + end + else + out_prev = i.condition(uprev, t, integrator) + out_curr = i.condition(u, t + dt, integrator) + if (out_prev * out_curr < zero(out_prev)) + prob = integrator.disco_probs[idx] + sol = solve(prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) tmp = sol[] if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) breakpointθ = tmp end end + idx += 1 end end breakpointθ diff --git a/lib/OrdinaryDiffEqCore/src/integrators/type.jl b/lib/OrdinaryDiffEqCore/src/integrators/type.jl index 6234de5143b..4f4fbc62d1a 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/type.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/type.jl @@ -146,4 +146,5 @@ mutable struct ODEIntegrator{ fsalfirst::FSALType fsallast::FSALType rng::RNGType + disco_probs::Vector{IntervalNonlinearProblem} end diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index b9ce19f7622..a1bdae6b409 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -81,6 +81,7 @@ function SciMLBase.__init( alias = ODEAliasSpecifier(), initializealg = DefaultInit(), rng = nothing, + disco_probs = nothing, kwargs... ) if prob isa SciMLBase.AbstractDAEProblem && alg isa OrdinaryDiffEqAlgorithm @@ -653,6 +654,26 @@ function SciMLBase.__init( fsalfirst, fsallast = get_fsalfirstlast(cache, rate_prototype) _rng = rng === nothing ? Random.default_rng() : rng + num_cb = 0 + for i in callbacks_internal.continuous_callbacks + num_cb += 1 + end + disco_probs = Vector{IntervalNonlinearProblem}(undef, num_cb) + idx = 1 + for (ind, i) in enumerate(callbacks_internal.continuous_callbacks) + if i.is_discontinuity && !(i isa VectorContinuousCallback) + #VCC problems handled in disco itself + u₁ = similar(u) + function zero_func(θ, p) + ode_interpolant!(u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) + out = i.condition(u₁, t + θ * integrator.dt, integrator) + out + end + disco_prob = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p) + disco_probs[idx] = disco_prob + end + idx+=1 + end integrator = ODEIntegrator{ typeof(_alg), isinplace(prob), uType, typeof(du), @@ -686,7 +707,7 @@ function SciMLBase.__init( isout, reeval_fsal, u_modified, reinitiailize, isdae, opts, stats, initializealg, differential_vars, - fsalfirst, fsallast, _rng + fsalfirst, fsallast, _rng, disco_probs ) if initialize_integrator diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 5f3b80afa7c..ba56111e12b 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,6 +1,7 @@ using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner +#TEST 1: SIMPLE DISCONTINUITY #test example discontinuous at u = 1 f(u, p, t) = u[1] < 1 ? [2u[1]] : [-3u[1] + 5] u0 = [0.1] @@ -10,38 +11,22 @@ prob = ODEProblem(f, u0, tspan) #define callback condition(u, t, integrator) = u[1] - 1 function affect!(integrator) + #println("fired callback at t=$(integrator.t), u=$(integrator.u[1])") integrator.u[1] += 10 end cb = ContinuousCallback(condition, affect!; is_discontinuity = true) cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 291.292 μs (8449 allocations: 266.47 KiB) +# 277.833 μs (8033 allocations: 251.14 KiB) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) -# 335.417 μs (10008 allocations: 311.08 KiB) - -rodas_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) -# 410.291 μs (16828 allocations: 594.05 KiB) -rodas_no_disco = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) -# 483.792 μs (17729 allocations: 639.31 KiB) - -bdf_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) -# 245.917 μs (20703 allocations: 665.16 KiB) -bdf_no_disco = solve(prob, FBDF(); callback = cb2, reltol = 1e-6) -# 269.333 μs (20477 allocations: 663.80 KiB) - -tsit_disco = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) -# same either way for some reason? check about this -tsit_no_disco = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) - -vern_disco = solve(prob, Vern7(); callback = cb, reltol = 1e-6) -# 111.125 μs (15629 allocations: 493.66 KiB) -vern_no_disco = solve(prob, Vern7(); callback = cb2, reltol = 1e-6) -# 83.666 μs (13326 allocations: 420.31 KiB) +# 343.041 μs (10008 allocations: 311.02 KiB) + @profview for i in 1:1000 solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) end +#TEST 2: TWO DISCONTINUITIES #two discontinuity functions function f(u, p, t) if u[1] < 1 @@ -49,7 +34,7 @@ function f(u, p, t) elseif u[1] < 2 [u[1] + 0.2] # region 2: continues increasing to hit u = 2 else - [-4u[1] + 12] # region 3: after 2, moves toward u ≈ 3 + [-4u[1] + 12] end end @@ -63,28 +48,27 @@ function affect1!(integrator) #println("Callback 1 fired at t=$(integrator.t), u=$(integrator.u[1])") end cb1 = ContinuousCallback(condition1, affect1!; is_discontinuity = true) +cb1f = ContinuousCallback(condition1, affect1!; is_discontinuity = false) condition2(u, t, integrator) = u[1] - 2 function affect2!(integrator) #println("Callback 2 fired at t=$(integrator.t), u=$(integrator.u[1])") end cb2 = ContinuousCallback(condition2, affect2!; is_discontinuity = true) +cb2f = ContinuousCallback(condition2, affect2!; is_discontinuity = false) cb = CallbackSet(cb1, cb2) +cb2 = CallbackSet(cb1f, cb2f) #disco solve -sol_disco = solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) +sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) +# 1.664 ms (43703 allocations: 1.35 MiB) #fixed order solve -sol_no_disco = solve(prob, RadauIIA5(is_disco = false); callback = cb, reltol = 1e-6) - -rodas_no_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) - -rodas_disco = solve(prob, Rodas5P(is_disco = true); callback = cb, reltol = 1e-6) - -bdf_no_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) +sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) +# 1.266 ms (37019 allocations: 1.12 MiB) -bdf_disco = solve(prob, FBDF(is_disco = true); callback = cb, reltol = 1e-6) +#TEST 3: EXPONENTIAL DISCONTINUITY # multiple exponential regions with sharp transitions function f_multi_exp!(du, u, p, t) if u[1] < 0.3 @@ -103,50 +87,32 @@ prob_multi = ODEProblem(f_multi_exp!, u0_multi, tspan_multi) #define callbacks cond_multi_1(u, t, integrator) = u[1] - 0.3 function affect_multi_1!(integrator) - println("Multi-exponential discontinuity 1 callback fired at t=$(integrator.t), u=$(integrator.u[1])") + #println("Multi-exponential discontinuity 1 callback fired at t=$(integrator.t), u=$(integrator.u[1])") end cb_multi_1 = ContinuousCallback(cond_multi_1, affect_multi_1!; is_discontinuity = true) +cb_multi_1f = ContinuousCallback(cond_multi_1, affect_multi_1!; is_discontinuity = false) cond_multi_2(u, t, integrator) = u[1] - 0.8 function affect_multi_2!(integrator) - println("Multi-exponential discontinuity 2 callback fired at t=$(integrator.t), u=$(integrator.u[1])") + #println("Multi-exponential discontinuity 2 callback fired at t=$(integrator.t), u=$(integrator.u[1])") end cb_multi_2 = ContinuousCallback(cond_multi_2, affect_multi_2!; is_discontinuity = true) +cb_multi_2f = ContinuousCallback(cond_multi_2, affect_multi_2!; is_discontinuity = false) cb_multi = CallbackSet(cb_multi_1, cb_multi_2) +cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) #disco solve -sol_disco = solve(prob_multi, RadauIIA5(is_disco = true); callback=cb_multi, reltol=1e-7, abstol=1e-9) +sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +# 202.834 μs (2770 allocations: 93.23 KiB) #fixed order solve -sol_no_disco = solve(prob_multi, RadauIIA5(is_disco = false); callback=cb_multi, reltol = 1e-7, abstol = 1e-9) - -# 2D system with exponential coupling and discontinuity -function f_2d_exp!(du, u, p, t) - if u[1] + u[2] < 1.0 - du[1] = 2 * exp(u[1]) - u[2] - du[2] = -3 * u[1] + 4 * exp(u[2]) - else - du[1] = u[1] - du[2] = u[2] - end -end - -u0_2d = [0.1, 0.2] -tspan_2d = (0.0, 2.0) -prob_2d = ODEProblem(f_2d_exp!, u0_2d, tspan_2d) +sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) +# 122.875 μs (1136 allocations: 54.52 KiB) -#define callback -cond_2d(u, t, integrator) = u[1] + u[2] - 1.0 -function affect_2d!(integrator) - println("2D exponential discontinuity callback fired at t=$(integrator.t), u=$(integrator.u)") - @test 0.98 < integrator.u[1] + integrator.u[2] < 1.02 +@profview for i in 1:1000 + solve(prob_multi, RadauIIA5(); callback = cb_multi, reltol = 1e-6) end -cb_2d = ContinuousCallback(cond_2d, affect_2d!; is_discontinuity = true) - -#disco solve -sol_disco = solve(prob_2d, RadauIIA5(is_disco = true); callback=cb_2d, reltol=1e-8, abstol=1e-10) -#fixed order solve -sol_no_disco = solve(prob_2d, RadauIIA5(is_disco = false); callback=cb_2d, reltol = 1e-8, abstol = 1e-10) +#TEST 4: STIFF DISCONTINUITY # very stiff discontinuous system function f_stiff_disc!(du, u, p, t) λ = p[1] # stiffness parameter @@ -164,15 +130,20 @@ prob_stiff = ODEProblem(f_stiff_disc!, u0_stiff, tspan_stiff, [100.0]) #define callback cond_stiff(u, t, integrator) = u[1] - 0.5 function affect_stiff!(integrator) - println("Stiff discontinuity callback fired at t=$(integrator.t), u=$(integrator.u[1])") + #println("Stiff discontinuity callback fired at t=$(integrator.t), u=$(integrator.u[1])") end cb_stiff = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = true) +cb_stiff_f = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = false) #disco solve -sol_disco = solve(prob_stiff, RadauIIA5(is_disco = true); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +sol_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +# 131.875 μs (1956 allocations: 74.03 KiB) #fixed order solve -sol_no_disco = solve(prob_stiff, RadauIIA5(is_disco = false); callback=cb_stiff, reltol = 1e-9, abstol = 1e-11) +sol_no_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) +# 119.417 μs (1480 allocations: 59.55 KiB) + +#TEST 5: MULTIPLE DISCONTINUITIES IN SMALL RANGE # multiple discontinuities in very small range (1e-6 apart, 5 discontinuities) function f_many_disc!(du, u, p, t) du[1] = u[1] + 1 # simple linear growth @@ -187,20 +158,26 @@ disc_values = [0.1 + i * 1e-6 for i = 0:4] # define callbacks for each discontinuity cbs_many = [] +cbs_many_f = [] for (i, disc_val) in enumerate(disc_values) local cond_func(u, t, integrator) = u[1] - disc_val function affect_func!(integrator) - println("Dense discontinuity $i fired at t=$(integrator.t), u=$(integrator.u[1])") + #println("Dense discontinuity $i fired at t=$(integrator.t), u=$(integrator.u[1])") end push!(cbs_many, ContinuousCallback(cond_func, affect_func!; is_discontinuity = true)) + push!(cbs_many_f, ContinuousCallback(cond_func, affect_func!; is_discontinuity = false)) end cb_many = CallbackSet(cbs_many...) +cb_many_f = CallbackSet(cbs_many_f...) #disco solve -sol_disco = solve(prob_many, RadauIIA5(is_disco = true); callback=cb_many, reltol=1e-10, abstol=1e-12) +sol_disco = solve(prob_many, RadauIIA5(); callback=cb_many, reltol=1e-10, abstol=1e-12) +# 111.333 μs (907 allocations: 36.94 KiB) #fixed order solve -sol_no_disco = solve(prob_many, RadauIIA5(is_disco = false); callback=cb_many, reltol=1e-10, abstol=1e-12) +sol_no_disco = solve(prob_many, RadauIIA5(); callback=cb_many_f, reltol=1e-10, abstol=1e-12) +# 111.666 μs (907 allocations: 36.94 KiB) +#TEST 6: DISCONTINUOUS DAE # discontinuous DAE with mass matrix # System: M * du/dt = f(u, p, t) # du[1]/dt = u[2] - u[1] @@ -228,41 +205,44 @@ function affect_dae!(integrator) #println("DAE discontinuity callback fired at t=$(integrator.t), u=$(integrator.u)") end cb_dae = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = true) +cb_daef = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = false) + +radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) +# 83.500 μs (769 allocations: 35.72 KiB) +radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol=1e-10) +# 101.542 μs (1230 allocations: 48.16 KiB) + +#TEST 7: VECTOR CALLBACK +function f!(du, u, p, t) + du[1] = -u[1] + du[2] = 0.2*u[1] - 0.1*u[2] +end + +u0 = [3.0, 0.0] +tspan = (0.0, 10.0) +prob = ODEProblem(f!, u0, tspan) + +# Two event surfaces: u[1] == 2.0 and u[1] == 1.0 +function condition!(out, u, t, integrator) + out[1] = u[1] - 2.0 + out[2] = u[1] - 1.0 +end + +# Discontinuous update to the state when an event fires +function affect!(integrator, idx) + if idx == 1 + # when u[1] crosses 2, kick u[2] up (jump discontinuity) + integrator.u[2] += 5.0 + elseif idx == 2 + # when u[1] crosses 1, reset u[2] + integrator.u[2] = 0.0 + end +end +cb = VectorContinuousCallback(condition!, affect!, 2;) +cb2 = VectorContinuousCallback(condition!, affect!, 2; is_discontinuity = false) -radau_no_disco = solve(prob_dae, RadauIIA5(is_disco = false); callback=cb_dae, reltol=1e-8, abstol=1e-10) - #83.500 μs (769 allocations: 35.72 KiB) -radau_disco = solve(prob_dae, RadauIIA5(is_disco = true); callback=cb_dae, reltol=1e-8, abstol=1e-10) - # 119.417 μs (1273 allocations: 55.42 KiB) -rodas_no_disco = solve(prob_dae, Rodas5P(); callback = cb_dae, reltol = 1e-6) -#= SciMLBase.DEStats -Number of function 1 evaluations: 312 -Number of function 2 evaluations: 0 -Number of W matrix evaluations: 34 -Number of linear solves: 272 -Number of Jacobians created: 19 -Number of nonlinear solver iterations: 0 -Number of nonlinear solver convergence failures: 0 -Number of fixed-point solver iterations: 0 -Number of fixed-point solver convergence failures: 0 -Number of rootfind condition calls: 213 -Number of accepted steps: 19 -Number of rejected steps: 15 =# -# 98.167 μs (550 allocations: 26.92 KiB) -rodas_disco = solve(prob_dae, Rodas5P(is_disco = true); callback = cb_dae, reltol = 1e-6) -#= SciMLBase.DEStats -Number of function 1 evaluations: 312 -Number of function 2 evaluations: 0 -Number of W matrix evaluations: 34 -Number of linear solves: 272 -Number of Jacobians created: 19 -Number of nonlinear solver iterations: 0 -Number of nonlinear solver convergence failures: 0 -Number of fixed-point solver iterations: 0 -Number of fixed-point solver convergence failures: 0 -Number of rootfind condition calls: 213 -Number of accepted steps: 19 -Number of rejected steps: 15 =# -# 97.541 μs (550 allocations: 26.92 KiB) -bdf_no_disco = solve(prob_dae, FBDF(); callback = cb_dae, reltol = 1e-6) -bdf_disco = solve(prob_dae, FBDF(is_disco = true); callback = cb_dae, reltol = 1e-6) \ No newline at end of file +sol_disco = solve(prob, RadauIIA5(); callback = cb) +# 62.041 μs (849 allocations: 41.64 KiB) +sol_no_disco = solve(prob, RadauIIA5(); callback = cb2) +# 37.375 μs (531 allocations: 25.23 KiB) \ No newline at end of file From a293d6b557846a71222cb2ebfbf98ad80b64726e Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 16 Mar 2026 16:56:05 -0400 Subject: [PATCH 09/27] small optimization --- lib/OrdinaryDiffEqCore/src/solve.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index a1bdae6b409..73eade112ba 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -654,11 +654,13 @@ function SciMLBase.__init( fsalfirst, fsallast = get_fsalfirstlast(cache, rate_prototype) _rng = rng === nothing ? Random.default_rng() : rng - num_cb = 0 + disco_cb_num = 0 for i in callbacks_internal.continuous_callbacks - num_cb += 1 + if i.is_discontinuity + disco_cb_num += 1 + end end - disco_probs = Vector{IntervalNonlinearProblem}(undef, num_cb) + disco_probs = Vector{IntervalNonlinearProblem}(undef, disco_cb_num) idx = 1 for (ind, i) in enumerate(callbacks_internal.continuous_callbacks) if i.is_discontinuity && !(i isa VectorContinuousCallback) From 6aa6dfa1522f8d4ed171f01d1841bc442473d224 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Sat, 28 Mar 2026 15:48:48 -0500 Subject: [PATCH 10/27] update disco to new approach --- lib/OrdinaryDiffEqCore/src/disco.jl | 15 ++++- .../src/integrators/type.jl | 1 + lib/OrdinaryDiffEqCore/src/solve.jl | 61 ++++++++++++++----- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 4 ++ 4 files changed, 64 insertions(+), 17 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index f402034a376..79e7eb72fa0 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -48,12 +48,21 @@ function find_discontinuity(u, uprev, integrator, cache) out_prev = i.condition(uprev, t, integrator) out_curr = i.condition(u, t + dt, integrator) if (out_prev * out_curr < zero(out_prev)) - prob = integrator.disco_probs[idx] - sol = solve(prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) + disco_prob = integrator.disco_probs[idx] + #disco_prob = integrator.disco_prob + disco_prob.f.f.dt = integrator.dt + disco_prob.f.f.uprev = uprev + disco_prob.f.f.u = u + disco_prob.f.f.k = integrator.k + disco_prob.f.f.cache = integrator.cache + disco_prob.f.f.differential_vars = integrator.differential_vars + disco_prob.f.f.idxs = integrator.opts.save_idxs + #disco_prob.f.f.callback = i + sol = solve(disco_prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) tmp = sol[] if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) breakpointθ = tmp - end + end end idx += 1 end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/type.jl b/lib/OrdinaryDiffEqCore/src/integrators/type.jl index 4f4fbc62d1a..c28e213e000 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/type.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/type.jl @@ -146,5 +146,6 @@ mutable struct ODEIntegrator{ fsalfirst::FSALType fsallast::FSALType rng::RNGType + #disco_prob::IntervalNonlinearProblem disco_probs::Vector{IntervalNonlinearProblem} end diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index 73eade112ba..da56cfddf6d 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -16,6 +16,27 @@ determine_controller_datatype(u, internalnorm, ts::Tuple{<:Number, <:Number}) = determine_controller_datatype(u::AbstractVector{<:Number}, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(DiffEqBase.value(internalnorm(u, ts[1]))), typeof(DiffEqBase.value(internalnorm(u, ts[2]))), eltype(float.(DiffEqBase.value(ts)))) determine_controller_datatype(u, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(float(DiffEqBase.value(ts[1]))), typeof(float(DiffEqBase.value(ts[2])))) # This seems to be an assumption implicitly taken somewhere +mutable struct zero_func_struct{uType, tType, rateType, CacheType} + #integrator_ref::IntegratorType + u₁::uType + callback::ContinuousCallback + dt::tType + uprev::uType + u::uType + k::Vector{rateType} + cache::CacheType + idxs::Union{Nothing, Vector{Int}} + differential_vars::Union{Nothing, Vector{Bool}} +end + +function (z::zero_func_struct)(θ, p) + #integrator = z.integrator_ref[]::ODEIntegrator + ode_interpolant!(z.u₁, θ, z.dt, z.uprev, z.u, z.k, z.cache, z.idxs, Val{0}, z.differential_vars) + #ode_interpolant!(z.u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) + out = z.callback.condition(z.u₁, z.dt + θ * z.dt, z) + out +end + function SciMLBase.__init( prob::Union{ SciMLBase.AbstractODEProblem, @@ -654,29 +675,38 @@ function SciMLBase.__init( fsalfirst, fsallast = get_fsalfirstlast(cache, rate_prototype) _rng = rng === nothing ? Random.default_rng() : rng - disco_cb_num = 0 + num_probs = 0 + integrator_ref = Ref{Union{DEIntegrator, Nothing}}(nothing) for i in callbacks_internal.continuous_callbacks - if i.is_discontinuity - disco_cb_num += 1 + if !(i isa VectorContinuousCallback) && i.is_discontinuity + num_probs += 1 end end - disco_probs = Vector{IntervalNonlinearProblem}(undef, disco_cb_num) + + disco_probs = Vector{IntervalNonlinearProblem}(undef, num_probs) idx = 1 - for (ind, i) in enumerate(callbacks_internal.continuous_callbacks) + for i in callbacks_internal.continuous_callbacks + if i.is_discontinuity && !(i isa VectorContinuousCallback) + u₁ = similar(u) + zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars) + disco_probs[idx] = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p) + idx += 1 + end + end + #= + disco_prob = nothing + integrator_ref = Ref{Union{DEIntegrator, Nothing}}(nothing) + for i in callbacks_internal.continuous_callbacks if i.is_discontinuity && !(i isa VectorContinuousCallback) #VCC problems handled in disco itself u₁ = similar(u) - function zero_func(θ, p) - ode_interpolant!(u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) - out = i.condition(u₁, t + θ * integrator.dt, integrator) - out - end + #zero_func = zero_func_struct(integrator_ref, u₁, i) + zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars) disco_prob = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p) - disco_probs[idx] = disco_prob + break end - idx+=1 - end - + end + =# integrator = ODEIntegrator{ typeof(_alg), isinplace(prob), uType, typeof(du), tType, typeof(p), @@ -711,6 +741,9 @@ function SciMLBase.__init( opts, stats, initializealg, differential_vars, fsalfirst, fsallast, _rng, disco_probs ) + #if (num_probs > 0) + integrator_ref[] = integrator + #end if initialize_integrator if isdae || SciMLBase.has_initializeprob(prob.f) || diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index ba56111e12b..db7903f2e2e 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,5 +1,7 @@ using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner +using Logging +global_logger(ConsoleLogger(stderr, Logging.Error)) #TEST 1: SIMPLE DISCONTINUITY #test example discontinuous at u = 1 @@ -19,6 +21,7 @@ cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) # 277.833 μs (8033 allocations: 251.14 KiB) +# curr update: 287.417 μs (8240 allocations: 258.56 KiB) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 343.041 μs (10008 allocations: 311.02 KiB) @@ -104,6 +107,7 @@ cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) #disco solve sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) # 202.834 μs (2770 allocations: 93.23 KiB) +# curr update: 238.416 μs (4426 allocations: 119.88 KiB) #fixed order solve sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) # 122.875 μs (1136 allocations: 54.52 KiB) From b3fa9e81e4039db94e3ff1c50981b50b6e8f8fcf Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Sun, 29 Mar 2026 22:17:45 -0500 Subject: [PATCH 11/27] update benchmarks --- lib/OrdinaryDiffEqCore/src/disco.jl | 2 +- lib/OrdinaryDiffEqCore/src/solve.jl | 2 +- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 33 +++++++++++++--------- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index 79e7eb72fa0..29bf26c090a 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -58,7 +58,7 @@ function find_discontinuity(u, uprev, integrator, cache) disco_prob.f.f.differential_vars = integrator.differential_vars disco_prob.f.f.idxs = integrator.opts.save_idxs #disco_prob.f.f.callback = i - sol = solve(disco_prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) + sol = solve(disco_prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) tmp = sol[] if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) breakpointθ = tmp diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index da56cfddf6d..287e8aa8985 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -26,7 +26,7 @@ mutable struct zero_func_struct{uType, tType, rateType, CacheType} k::Vector{rateType} cache::CacheType idxs::Union{Nothing, Vector{Int}} - differential_vars::Union{Nothing, Vector{Bool}} + differential_vars::Union{Nothing, Vector{Bool}, BitVector} end function (z::zero_func_struct)(θ, p) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index db7903f2e2e..1476e4a3946 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,5 +1,5 @@ using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra -using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner +using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner using Logging global_logger(ConsoleLogger(stderr, Logging.Error)) @@ -20,11 +20,9 @@ cb = ContinuousCallback(condition, affect!; is_discontinuity = true) cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 277.833 μs (8033 allocations: 251.14 KiB) -# curr update: 287.417 μs (8240 allocations: 258.56 KiB) +# 286.125 μs (8207 allocations: 258.09 KiB) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) -# 343.041 μs (10008 allocations: 311.02 KiB) - +# 340.292 μs (10009 allocations: 311.05 KiB) @profview for i in 1:1000 solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) end @@ -64,11 +62,14 @@ cb2 = CallbackSet(cb1f, cb2f) #disco solve sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 1.664 ms (43703 allocations: 1.35 MiB) +# 1.548 ms (46763 allocations: 1.35 MiB) #fixed order solve sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) -# 1.266 ms (37019 allocations: 1.12 MiB) +# 1.264 ms (37026 allocations: 1.13 MiB) +@profview for i in 1:1000 + solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) +end #TEST 3: EXPONENTIAL DISCONTINUITY @@ -106,11 +107,10 @@ cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) #disco solve sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) -# 202.834 μs (2770 allocations: 93.23 KiB) -# curr update: 238.416 μs (4426 allocations: 119.88 KiB) +# 195.666 μs (3834 allocations: 110.72 KiB) #fixed order solve sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) -# 122.875 μs (1136 allocations: 54.52 KiB) +# 125.583 μs (1134 allocations: 54.56 KiB) @profview for i in 1:1000 solve(prob_multi, RadauIIA5(); callback = cb_multi, reltol = 1e-6) @@ -141,11 +141,14 @@ cb_stiff_f = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = fa #disco solve sol_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) -# 131.875 μs (1956 allocations: 74.03 KiB) +# 131.375 μs (2181 allocations: 78.84 KiB) #fixed order solve sol_no_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) # 119.417 μs (1480 allocations: 59.55 KiB) +@profview for i in 1:1000 + solve(prob_stiff, RadauIIA5(); callback = cb_stiff, reltol = 1e-9, abstol = 1e-11) +end #TEST 5: MULTIPLE DISCONTINUITIES IN SMALL RANGE # multiple discontinuities in very small range (1e-6 apart, 5 discontinuities) @@ -176,11 +179,15 @@ cb_many_f = CallbackSet(cbs_many_f...) #disco solve sol_disco = solve(prob_many, RadauIIA5(); callback=cb_many, reltol=1e-10, abstol=1e-12) -# 111.333 μs (907 allocations: 36.94 KiB) +# 169.541 μs (1479 allocations: 73.98 KiB) #fixed order solve sol_no_disco = solve(prob_many, RadauIIA5(); callback=cb_many_f, reltol=1e-10, abstol=1e-12) # 111.666 μs (907 allocations: 36.94 KiB) +@profview for i in 1:1000 + solve(prob_many, RadauIIA5(); callback = cb_many, reltol = 1e-10, abstol = 1e-12) +end + #TEST 6: DISCONTINUOUS DAE # discontinuous DAE with mass matrix # System: M * du/dt = f(u, p, t) @@ -214,7 +221,7 @@ cb_daef = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = false) radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) # 83.500 μs (769 allocations: 35.72 KiB) radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol=1e-10) -# 101.542 μs (1230 allocations: 48.16 KiB) +# 104.292 μs (1494 allocations: 53.25 KiB) #TEST 7: VECTOR CALLBACK function f!(du, u, p, t) From af7c3a50f7b91aab9a1ad0c6fa512178e6aab618 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 30 Mar 2026 19:11:37 -0500 Subject: [PATCH 12/27] further optimizations --- lib/OrdinaryDiffEqCore/src/disco.jl | 24 ++++++++++++++-------- lib/OrdinaryDiffEqCore/src/solve.jl | 10 ++++----- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 6 +++--- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index 29bf26c090a..b8abb3bb633 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -16,6 +16,13 @@ function find_discontinuity(u, uprev, integrator, cache) p = integrator.p t = integrator.t dt = integrator.dt + save_idxs = integrator.opts.save_idxs + k = integrator.k + cache = integrator.cache + differential_vars = integrator.differential_vars + θlo = zero(dt) + θhi = one(dt) + bracket = [θlo, θhi] breakpointθ = -one(dt) idx = 1 for i in cb.continuous_callbacks @@ -50,15 +57,16 @@ function find_discontinuity(u, uprev, integrator, cache) if (out_prev * out_curr < zero(out_prev)) disco_prob = integrator.disco_probs[idx] #disco_prob = integrator.disco_prob - disco_prob.f.f.dt = integrator.dt - disco_prob.f.f.uprev = uprev - disco_prob.f.f.u = u - disco_prob.f.f.k = integrator.k - disco_prob.f.f.cache = integrator.cache - disco_prob.f.f.differential_vars = integrator.differential_vars - disco_prob.f.f.idxs = integrator.opts.save_idxs + disco_zero = disco_prob.f.f + disco_zero.dt = dt + disco_zero.uprev = uprev + disco_zero.u = u + disco_zero.k = k + disco_zero.cache = cache + disco_zero.differential_vars = differential_vars + disco_zero.idxs = save_idxs #disco_prob.f.f.callback = i - sol = solve(disco_prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) + sol = solve(disco_prob; bracket = bracket, abstol = 0, reltol = 0) tmp = sol[] if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) breakpointθ = tmp diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index 287e8aa8985..220f7e50d6e 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -16,17 +16,17 @@ determine_controller_datatype(u, internalnorm, ts::Tuple{<:Number, <:Number}) = determine_controller_datatype(u::AbstractVector{<:Number}, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(DiffEqBase.value(internalnorm(u, ts[1]))), typeof(DiffEqBase.value(internalnorm(u, ts[2]))), eltype(float.(DiffEqBase.value(ts)))) determine_controller_datatype(u, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(float(DiffEqBase.value(ts[1]))), typeof(float(DiffEqBase.value(ts[2])))) # This seems to be an assumption implicitly taken somewhere -mutable struct zero_func_struct{uType, tType, rateType, CacheType} +mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsType, callbackType} #integrator_ref::IntegratorType u₁::uType - callback::ContinuousCallback + callback::callbackType dt::tType uprev::uType u::uType - k::Vector{rateType} + k::kType cache::CacheType - idxs::Union{Nothing, Vector{Int}} - differential_vars::Union{Nothing, Vector{Bool}, BitVector} + idxs::idxsType + differential_vars::varsType end function (z::zero_func_struct)(θ, p) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 1476e4a3946..54b9f900a42 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -20,7 +20,7 @@ cb = ContinuousCallback(condition, affect!; is_discontinuity = true) cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 286.125 μs (8207 allocations: 258.09 KiB) +# 283.292 μs (8113 allocations: 256.59 KiB) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 340.292 μs (10009 allocations: 311.05 KiB) @profview for i in 1:1000 @@ -62,7 +62,7 @@ cb2 = CallbackSet(cb1f, cb2f) #disco solve sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 1.548 ms (46763 allocations: 1.35 MiB) +# 1.460 ms (41491 allocations: 1.26 MiB) #fixed order solve sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 1.264 ms (37026 allocations: 1.13 MiB) @@ -107,7 +107,7 @@ cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) #disco solve sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) -# 195.666 μs (3834 allocations: 110.72 KiB) +# 159.125 μs (1819 allocations: 79.06 KiB) #fixed order solve sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) # 125.583 μs (1134 allocations: 54.56 KiB) From 3bd834fb7b706f1aa0a41f9281f55c7bbb501539 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Tue, 31 Mar 2026 10:26:31 -0500 Subject: [PATCH 13/27] fix benchmarks and merge issues --- lib/OrdinaryDiffEqCore/src/solve.jl | 1 - lib/OrdinaryDiffEqCore/test/disco_tests.jl | 28 +++++++++++----------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index a2f8a5efdb3..644672abdb6 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -822,7 +822,6 @@ function _ode_init( u_modified, reinitialize, isdae, opts, stats, initializealg, differential_vars, fsalfirst, fsallast, _rng, disco_probs, - fsalfirst, fsallast, _rng, W, P, sqdt, noise, c, rate_constants, QT(1) ) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 54b9f900a42..246be2015f1 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -20,9 +20,9 @@ cb = ContinuousCallback(condition, affect!; is_discontinuity = true) cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 283.292 μs (8113 allocations: 256.59 KiB) +# 298.084 μs (8108 allocations: 257.11 KiB) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) -# 340.292 μs (10009 allocations: 311.05 KiB) +# 356.708 μs (10024 allocations: 312.08 KiB) @profview for i in 1:1000 solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) end @@ -62,10 +62,10 @@ cb2 = CallbackSet(cb1f, cb2f) #disco solve sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 1.460 ms (41491 allocations: 1.26 MiB) +# 1.503 ms (41672 allocations: 1.27 MiB) #fixed order solve sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) -# 1.264 ms (37026 allocations: 1.13 MiB) +# 1.306 ms (37092 allocations: 1.13 MiB) @profview for i in 1:1000 solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) @@ -107,10 +107,10 @@ cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) #disco solve sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) -# 159.125 μs (1819 allocations: 79.06 KiB) +# 175.625 μs (1871 allocations: 81.55 KiB) #fixed order solve sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) -# 125.583 μs (1134 allocations: 54.56 KiB) +# 142.875 μs (1244 allocations: 59.17 KiB) @profview for i in 1:1000 solve(prob_multi, RadauIIA5(); callback = cb_multi, reltol = 1e-6) @@ -141,10 +141,10 @@ cb_stiff_f = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = fa #disco solve sol_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) -# 131.375 μs (2181 allocations: 78.84 KiB) +# 149.167 μs (1819 allocations: 75.19 KiB) #fixed order solve sol_no_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) -# 119.417 μs (1480 allocations: 59.55 KiB) +# 138.125 μs (1565 allocations: 64.09 KiB) @profview for i in 1:1000 solve(prob_stiff, RadauIIA5(); callback = cb_stiff, reltol = 1e-9, abstol = 1e-11) @@ -179,10 +179,10 @@ cb_many_f = CallbackSet(cbs_many_f...) #disco solve sol_disco = solve(prob_many, RadauIIA5(); callback=cb_many, reltol=1e-10, abstol=1e-12) -# 169.541 μs (1479 allocations: 73.98 KiB) +# 182.541 μs (1489 allocations: 73.64 KiB) #fixed order solve sol_no_disco = solve(prob_many, RadauIIA5(); callback=cb_many_f, reltol=1e-10, abstol=1e-12) -# 111.666 μs (907 allocations: 36.94 KiB) +# 121.292 μs (923 allocations: 36.78 KiB) @profview for i in 1:1000 solve(prob_many, RadauIIA5(); callback = cb_many, reltol = 1e-10, abstol = 1e-12) @@ -218,11 +218,11 @@ end cb_dae = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = true) cb_daef = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = false) -radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) -# 83.500 μs (769 allocations: 35.72 KiB) radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol=1e-10) -# 104.292 μs (1494 allocations: 53.25 KiB) - +# 88.542 μs (870 allocations: 41.86 KiB) +radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) +# 73.000 μs (673 allocations: 32.05 KiB) + #TEST 7: VECTOR CALLBACK function f!(du, u, p, t) du[1] = -u[1] From 90a9bca6eec6dcb619fc7db2bd7ea42f17f56e33 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 6 Apr 2026 19:25:15 -0500 Subject: [PATCH 14/27] update vector continuous callback handling --- lib/OrdinaryDiffEqCore/src/disco.jl | 44 ++++------ .../src/integrators/integrator_interface.jl | 2 +- lib/OrdinaryDiffEqCore/src/solve.jl | 45 ++++------ lib/OrdinaryDiffEqCore/test/disco_tests.jl | 82 +++++++++---------- 4 files changed, 74 insertions(+), 99 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index b8abb3bb633..9f7191874e7 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -29,51 +29,43 @@ function find_discontinuity(u, uprev, integrator, cache) if (!(i.is_discontinuity)) continue end + disco_prob = integrator.disco_probs[idx] + disco_zero = disco_prob.f.f + disco_zero.dt = dt + disco_zero.uprev = uprev + disco_zero.u = u + disco_zero.k = k + disco_zero.cache = cache + disco_zero.differential_vars = differential_vars + disco_zero.idxs = save_idxs if (i isa VectorContinuousCallback) + len_cb = i.len out_prev = similar(u) - out_curr = similar(u) + out_curr = similar(u) i.condition(out_prev, uprev, t, integrator) i.condition(out_curr, u, t + dt, integrator) - for (ind, (f0, f1)) in enumerate(zip(out_prev, out_curr)) - if (f0 * f1 < zero(f0)) - u₁ = similar(u) - out = similar(u) - function zero_func(θ, p) - ode_interpolant!(u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) - i.condition(out, u₁, t + θ * integrator.dt, integrator) - out[ind] - end - prob = IntervalNonlinearProblem(zero_func, [zero(dt), one(dt)], p) - sol = solve(prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) + for j in 1:len_cb + if (out_prev[j] * out_curr[j] < zero(out_prev[j])) + disco_zero.ind = j + sol = solve(disco_prob; bracket = bracket) tmp = sol[] if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) breakpointθ = tmp - end + end end end else out_prev = i.condition(uprev, t, integrator) out_curr = i.condition(u, t + dt, integrator) if (out_prev * out_curr < zero(out_prev)) - disco_prob = integrator.disco_probs[idx] - #disco_prob = integrator.disco_prob - disco_zero = disco_prob.f.f - disco_zero.dt = dt - disco_zero.uprev = uprev - disco_zero.u = u - disco_zero.k = k - disco_zero.cache = cache - disco_zero.differential_vars = differential_vars - disco_zero.idxs = save_idxs - #disco_prob.f.f.callback = i - sol = solve(disco_prob; bracket = bracket, abstol = 0, reltol = 0) + sol = solve(disco_prob; bracket = bracket) tmp = sol[] if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) breakpointθ = tmp end end - idx += 1 end + idx += 1 end breakpointθ end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl index 8302ee706b6..4d988c21a57 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl @@ -144,7 +144,7 @@ end end end -function u_modified!(integrator::ODEIntegrator, bool::Bool) +function SciMLBase.u_modified!(integrator::ODEIntegrator, bool::Bool) return integrator.u_modified = bool end diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index 644672abdb6..c1d00c70cce 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -16,7 +16,7 @@ determine_controller_datatype(u, internalnorm, ts::Tuple{<:Number, <:Number}) = determine_controller_datatype(u::AbstractVector{<:Number}, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(DiffEqBase.value(internalnorm(u, ts[1]))), typeof(DiffEqBase.value(internalnorm(u, ts[2]))), eltype(float.(DiffEqBase.value(ts)))) determine_controller_datatype(u, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(float(DiffEqBase.value(ts[1]))), typeof(float(DiffEqBase.value(ts[2])))) # This seems to be an assumption implicitly taken somewhere -mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsType, callbackType} +mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsType, callbackType, outType} #integrator_ref::IntegratorType u₁::uType callback::callbackType @@ -27,14 +27,19 @@ mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsTy cache::CacheType idxs::idxsType differential_vars::varsType + ind::Int + out::outType end function (z::zero_func_struct)(θ, p) - #integrator = z.integrator_ref[]::ODEIntegrator ode_interpolant!(z.u₁, θ, z.dt, z.uprev, z.u, z.k, z.cache, z.idxs, Val{0}, z.differential_vars) - #ode_interpolant!(z.u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) - out = z.callback.condition(z.u₁, z.dt + θ * z.dt, z) - out + return zero_condition(z.callback, z.out, z.u₁, z.dt + θ * z.dt, z, z.ind) +end + +@inline zero_condition(cb::ContinuousCallback, out::Nothing, u, t, z, ind) = cb.condition(u, t, z) +@inline function zero_condition(cb::VectorContinuousCallback, out, u, t, z, ind) + cb.condition(out, u, t, z) + return out[ind] end function SciMLBase.__init( @@ -755,38 +760,25 @@ function _ode_init( get_fsalfirstlast(cache, rate_prototype) _rng = rng === nothing ? Random.default_rng() : rng + num_probs = 0 - integrator_ref = Ref{Union{DEIntegrator, Nothing}}(nothing) for i in callbacks_internal.continuous_callbacks - if !(i isa VectorContinuousCallback) && i.is_discontinuity + if i.is_discontinuity num_probs += 1 end end - disco_probs = Vector{IntervalNonlinearProblem}(undef, num_probs) idx = 1 for i in callbacks_internal.continuous_callbacks - if i.is_discontinuity && !(i isa VectorContinuousCallback) + if i.is_discontinuity u₁ = similar(u) - zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars) + out = i isa VectorContinuousCallback ? similar(u) : nothing + zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars, 1, out) disco_probs[idx] = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p) idx += 1 end end - #= - disco_prob = nothing - integrator_ref = Ref{Union{DEIntegrator, Nothing}}(nothing) - for i in callbacks_internal.continuous_callbacks - if i.is_discontinuity && !(i isa VectorContinuousCallback) - #VCC problems handled in disco itself - u₁ = similar(u) - #zero_func = zero_func_struct(integrator_ref, u₁, i) - zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars) - disco_prob = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p) - break - end - end - =# + integrator = ODEIntegrator{ typeof(_alg), isinplace(prob), uType, typeof(du), tType, typeof(p), @@ -798,7 +790,7 @@ function _ode_init( typeof(initializealg), typeof(differential_vars), typeof(controller_cache), typeof(_rng), typeof(W), typeof(P), typeof(sqdt), - typeof(noise), typeof(c), typeof(rate_constants), + typeof(noise), typeof(c), typeof(rate_constants) }( sol, u, du, k, t, tType(_dt), f, p, uprev, uprev2, duprev, tprev, @@ -825,9 +817,6 @@ function _ode_init( W, P, sqdt, noise, c, rate_constants, QT(1) ) - #if (num_probs > 0) - integrator_ref[] = integrator - #end if initialize_integrator if isdae || SciMLBase.has_initializeprob(prob.f) || diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 246be2015f1..89ca91e34c3 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -150,45 +150,7 @@ sol_no_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9 solve(prob_stiff, RadauIIA5(); callback = cb_stiff, reltol = 1e-9, abstol = 1e-11) end -#TEST 5: MULTIPLE DISCONTINUITIES IN SMALL RANGE -# multiple discontinuities in very small range (1e-6 apart, 5 discontinuities) -function f_many_disc!(du, u, p, t) - du[1] = u[1] + 1 # simple linear growth -end - -u0_many = [0.0] -tspan_many = (0.0, 1.0) -prob_many = ODEProblem(f_many_disc!, u0_many, tspan_many) - -# create 5 discontinuities spaced 1e-6 apart -disc_values = [0.1 + i * 1e-6 for i = 0:4] - -# define callbacks for each discontinuity -cbs_many = [] -cbs_many_f = [] -for (i, disc_val) in enumerate(disc_values) - local cond_func(u, t, integrator) = u[1] - disc_val - function affect_func!(integrator) - #println("Dense discontinuity $i fired at t=$(integrator.t), u=$(integrator.u[1])") - end - push!(cbs_many, ContinuousCallback(cond_func, affect_func!; is_discontinuity = true)) - push!(cbs_many_f, ContinuousCallback(cond_func, affect_func!; is_discontinuity = false)) -end -cb_many = CallbackSet(cbs_many...) -cb_many_f = CallbackSet(cbs_many_f...) - -#disco solve -sol_disco = solve(prob_many, RadauIIA5(); callback=cb_many, reltol=1e-10, abstol=1e-12) -# 182.541 μs (1489 allocations: 73.64 KiB) -#fixed order solve -sol_no_disco = solve(prob_many, RadauIIA5(); callback=cb_many_f, reltol=1e-10, abstol=1e-12) -# 121.292 μs (923 allocations: 36.78 KiB) - -@profview for i in 1:1000 - solve(prob_many, RadauIIA5(); callback = cb_many, reltol = 1e-10, abstol = 1e-12) -end - -#TEST 6: DISCONTINUOUS DAE +#TEST 5: DISCONTINUOUS DAE # discontinuous DAE with mass matrix # System: M * du/dt = f(u, p, t) # du[1]/dt = u[2] - u[1] @@ -223,7 +185,7 @@ radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol= radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) # 73.000 μs (673 allocations: 32.05 KiB) -#TEST 7: VECTOR CALLBACK +#TEST 6: VECTOR CALLBACK function f!(du, u, p, t) du[1] = -u[1] du[2] = 0.2*u[1] - 0.1*u[2] @@ -250,10 +212,42 @@ function affect!(integrator, idx) end end -cb = VectorContinuousCallback(condition!, affect!, 2;) -cb2 = VectorContinuousCallback(condition!, affect!, 2; is_discontinuity = false) +cb = VectorContinuousCallback(condition!, affect!, 2; is_discontinuity = true) +cb2 = VectorContinuousCallback(condition!, affect!, 2; is_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb) -# 62.041 μs (849 allocations: 41.64 KiB) +# 49.125 μs (664 allocations: 32.89 KiB) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2) -# 37.375 μs (531 allocations: 25.23 KiB) \ No newline at end of file +# 37.375 μs (531 allocations: 25.23 KiB) + +@profview for i in 1:1000 + solve(prob, RadauIIA5(); callback = cb) +end + +#TEST 7 +function f!(du, u, p, t) + x1, x2 = u + du[1] = x2 + if x2 < 0.0 + du[2] = -x1 + 1.0 + else + du[2] = -x1 - 1.0 + end +end + +u = [1.5, 0.8] +tspan = (0.0, 2.0) +prob = ODEProblem(f!, u, tspan) + +cond(u, t, integrator) = u[2] +affect!(integrator) = nothing + +cb = ContinuousCallback(cond, affect!; is_discontinuity = true) +cb2 = ContinuousCallback(cond, affect!; is_discontinuity = false) + +sol_disco = solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10) +sol_no_disco = solve(prob, Tsit5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) + +@profview for i in 1:1000 + solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10) +end From 6902dcea6b098f0548bb484e2a97ce434b444d96 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Sat, 18 Apr 2026 15:39:27 -0500 Subject: [PATCH 15/27] update tests --- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 91 +++++++++++++++------- 1 file changed, 65 insertions(+), 26 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 89ca91e34c3..99374969fcb 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,5 +1,5 @@ using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra -using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner +using OrdinaryDiffEqTsit5, OrdinaryDiffEqRosenbrock using Logging global_logger(ConsoleLogger(stderr, Logging.Error)) @@ -19,13 +19,20 @@ end cb = ContinuousCallback(condition, affect!; is_discontinuity = true) cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) -sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) +sol_disco_radau = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) # 298.084 μs (8108 allocations: 257.11 KiB) -sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) +sol_no_disco_radau = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 356.708 μs (10024 allocations: 312.08 KiB) -@profview for i in 1:1000 - solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -end + +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) +# 418.584 μs (16472 allocations: 576.75 KiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) +# 440.375 μs (17875 allocations: 622.09 KiB) + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) +# 59.542 μs (7248 allocations: 233.67 KiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) +# 46.500 μs (7129 allocations: 226.22 KiB) #TEST 2: TWO DISCONTINUITIES #two discontinuity functions @@ -63,13 +70,18 @@ cb2 = CallbackSet(cb1f, cb2f) #disco solve sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) # 1.503 ms (41672 allocations: 1.27 MiB) -#fixed order solve sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 1.306 ms (37092 allocations: 1.13 MiB) -@profview for i in 1:1000 - solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -end +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) +# 1.164 ms (44318 allocations: 1.52 MiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) +# 1.306 ms (51713 allocations: 1.76 MiB) + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) +# 279.792 μs (34573 allocations: 1.07 MiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) +# 266.167 μs (39024 allocations: 1.21 MiB) #TEST 3: EXPONENTIAL DISCONTINUITY @@ -108,13 +120,18 @@ cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) #disco solve sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) # 175.625 μs (1871 allocations: 81.55 KiB) -#fixed order solve sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) # 142.875 μs (1244 allocations: 59.17 KiB) -@profview for i in 1:1000 - solve(prob_multi, RadauIIA5(); callback = cb_multi, reltol = 1e-6) -end +sol_disco_rosenbrock = solve(prob_multi, Rodas5P(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +# 295.834 μs (2216 allocations: 90.70 KiB) +sol_no_disco_rosenbrock = solve(prob_multi, Rodas5P(); callback=cb_multi2, reltol=1e-7, abstol=1e-9) +# 253.709 μs (1380 allocations: 74.28 KiB) + +sol_disco_tsit5 = solve(prob_multi, Tsit5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +# 127.375 μs (1953 allocations: 87.49 KiB) +sol_no_disco_tsit5 = solve(prob_multi, Tsit5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) +# 95.250 μs (1499 allocations: 73.62 KiB) #TEST 4: STIFF DISCONTINUITY # very stiff discontinuous system @@ -142,13 +159,18 @@ cb_stiff_f = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = fa #disco solve sol_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) # 149.167 μs (1819 allocations: 75.19 KiB) -#fixed order solve sol_no_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) # 138.125 μs (1565 allocations: 64.09 KiB) -@profview for i in 1:1000 - solve(prob_stiff, RadauIIA5(); callback = cb_stiff, reltol = 1e-9, abstol = 1e-11) -end +sol_disco_rosenbrock = solve(prob_stiff, Rodas5P(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +# 204.833 μs (1517 allocations: 59.33 KiB) +sol_no_disco_rosenbrock = solve(prob_stiff, Rodas5P(); callback=cb_stiff_f, reltol=1e-9, abstol=1e-11) +# 156.500 μs (1047 allocations: 44.59 KiB) + +sol_disco_tsit5 = solve(prob_stiff, Tsit5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +# 93.833 μs (2040 allocations: 80.59 KiB) +sol_no_disco_tsit5 = solve(prob_stiff, Tsit5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) +# 82.750 μs (1898 allocations: 72.12 KiB) #TEST 5: DISCONTINUOUS DAE # discontinuous DAE with mass matrix @@ -185,6 +207,11 @@ radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol= radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) # 73.000 μs (673 allocations: 32.05 KiB) +sol_disco_rosenbrock = solve(prob_dae, Rodas5P(); callback=cb_dae, reltol=1e-8, abstol=1e-10) +# 312.167 μs (1200 allocations: 48.73 KiB) +sol_no_disco_rosenbrock = solve(prob_dae, Rodas5P(); callback=cb_daef, reltol=1e-8, abstol=1e-10) +# 256.792 μs (672 allocations: 32.56 KiB) + #TEST 6: VECTOR CALLBACK function f!(du, u, p, t) du[1] = -u[1] @@ -220,9 +247,15 @@ sol_disco = solve(prob, RadauIIA5(); callback = cb) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2) # 37.375 μs (531 allocations: 25.23 KiB) -@profview for i in 1:1000 - solve(prob, RadauIIA5(); callback = cb) -end +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb) +# 57.333 μs (592 allocations: 31.23 KiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2) +# 44.250 μs (476 allocations: 23.73 KiB) + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb) +# 37.833 μs (673 allocations: 31.80 KiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2) +# 24.958 μs (557 allocations: 24.23 KiB) #TEST 7 function f!(du, u, p, t) @@ -245,9 +278,15 @@ affect!(integrator) = nothing cb = ContinuousCallback(cond, affect!; is_discontinuity = true) cb2 = ContinuousCallback(cond, affect!; is_discontinuity = false) -sol_disco = solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10) -sol_no_disco = solve(prob, Tsit5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) +sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-8, abstol = 1e-10) +sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) -@profview for i in 1:1000 - solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10) -end +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-8, abstol = 1e-10) +# 240.291 μs (1821 allocations: 71.56 KiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-8, abstol = 1e-10) +# 184.625 μs (1029 allocations: 49.23 KiB) + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10) +# 79.791 μs (1678 allocations: 73.85 KiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) +# 55.958 μs (1259 allocations: 57.04 KiB) \ No newline at end of file From b472c7c91701282ce5d35b50f0949c431bcd78c3 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Wed, 29 Apr 2026 12:03:30 -0500 Subject: [PATCH 16/27] add steps for vern, BS, etc --- lib/OrdinaryDiffEqCore/src/disco.jl | 3 +++ lib/OrdinaryDiffEqCore/src/integrators/type.jl | 2 +- lib/OrdinaryDiffEqCore/src/solve.jl | 8 ++++++-- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 18 +++++++++++++++++- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index 9f7191874e7..cf20b301292 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -38,6 +38,9 @@ function find_discontinuity(u, uprev, integrator, cache) disco_zero.cache = cache disco_zero.differential_vars = differential_vars disco_zero.idxs = save_idxs + disco_zero.tprev = t + disco_zero.f = integrator.f + disco_zero.p = p if (i isa VectorContinuousCallback) len_cb = i.len out_prev = similar(u) diff --git a/lib/OrdinaryDiffEqCore/src/integrators/type.jl b/lib/OrdinaryDiffEqCore/src/integrators/type.jl index 766073b248e..40177eb8bb5 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/type.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/type.jl @@ -193,7 +193,7 @@ mutable struct ODEIntegrator{ fsallast::FSALType rng::RNGType #disco_prob::IntervalNonlinearProblem - disco_probs::Vector{IntervalNonlinearProblem} + disco_probs::Vector{IntervalNonlinearProblem} #should we change this? W::WType P::PType sqdt::SqdtType diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index c1d00c70cce..cb0b35bd33a 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -16,7 +16,7 @@ determine_controller_datatype(u, internalnorm, ts::Tuple{<:Number, <:Number}) = determine_controller_datatype(u::AbstractVector{<:Number}, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(DiffEqBase.value(internalnorm(u, ts[1]))), typeof(DiffEqBase.value(internalnorm(u, ts[2]))), eltype(float.(DiffEqBase.value(ts)))) determine_controller_datatype(u, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(float(DiffEqBase.value(ts[1]))), typeof(float(DiffEqBase.value(ts[2])))) # This seems to be an assumption implicitly taken somewhere -mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsType, callbackType, outType} +mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsType, callbackType, outType, FunctionType, tType2, ParameterType} #integrator_ref::IntegratorType u₁::uType callback::callbackType @@ -29,9 +29,13 @@ mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsTy differential_vars::varsType ind::Int out::outType + f::FunctionType + tprev::tType2 + p::ParameterType end function (z::zero_func_struct)(θ, p) + _ode_addsteps!(z.k, z.tprev, z.uprev, z.u, z.dt, z.f, z.p, z.cache, false, true, false) ode_interpolant!(z.u₁, θ, z.dt, z.uprev, z.u, z.k, z.cache, z.idxs, Val{0}, z.differential_vars) return zero_condition(z.callback, z.out, z.u₁, z.dt + θ * z.dt, z, z.ind) end @@ -773,7 +777,7 @@ function _ode_init( if i.is_discontinuity u₁ = similar(u) out = i isa VectorContinuousCallback ? similar(u) : nothing - zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars, 1, out) + zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars, 1, out, f, tprev, p) disco_probs[idx] = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p) idx += 1 end diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 99374969fcb..17b3f49dc72 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,7 +1,8 @@ using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra -using OrdinaryDiffEqTsit5, OrdinaryDiffEqRosenbrock +using OrdinaryDiffEqTsit5, OrdinaryDiffEqRosenbrock, OrdinaryDiffEqLowOrderRK using Logging global_logger(ConsoleLogger(stderr, Logging.Error)) +using BenchmarkTools #TEST 1: SIMPLE DISCONTINUITY #test example discontinuous at u = 1 @@ -34,6 +35,9 @@ sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) # 46.500 μs (7129 allocations: 226.22 KiB) +sol_disco_BS5= solve(prob, BS5(); callback = cb, reltol = 1e-6) +sol_no_disco_BS5 = solve(prob, BS5(); callback = cb2, reltol = 1e-6) + #TEST 2: TWO DISCONTINUITIES #two discontinuity functions function f(u, p, t) @@ -133,6 +137,13 @@ sol_disco_tsit5 = solve(prob_multi, Tsit5(); callback=cb_multi, reltol=1e-7, abs sol_no_disco_tsit5 = solve(prob_multi, Tsit5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) # 95.250 μs (1499 allocations: 73.62 KiB) +sol_disco_BS3 = solve(prob_multi, BS3(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +sol_no_disco_BS3 = solve(prob_multi, BS3(); callback=cb_multi2, reltol=1e-7, abstol=1e-9) + +@profview for i in 1:1000 + solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +end + #TEST 4: STIFF DISCONTINUITY # very stiff discontinuous system function f_stiff_disc!(du, u, p, t) @@ -172,6 +183,11 @@ sol_disco_tsit5 = solve(prob_stiff, Tsit5(); callback=cb_stiff, reltol=1e-9, abs sol_no_disco_tsit5 = solve(prob_stiff, Tsit5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) # 82.750 μs (1898 allocations: 72.12 KiB) +sol_disco_BS3 = solve(prob_stiff, BS3(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +# 1.121 ms (12460 allocations: 595.30 KiB) +sol_no_disco_BS3 = solve(prob_stiff, BS3(); callback=cb_stiff_f, reltol=1e-9, abstol=1e-11) +# 1.102 ms (12229 allocations: 582.34 KiB) + #TEST 5: DISCONTINUOUS DAE # discontinuous DAE with mass matrix # System: M * du/dt = f(u, p, t) From 77ea57ef4adc708cfc31e63413403b48ea395c81 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Fri, 1 May 2026 16:02:16 -0500 Subject: [PATCH 17/27] fix rebase errors --- lib/OrdinaryDiffEqBDF/src/controllers.jl | 3 +- lib/OrdinaryDiffEqCore/src/solve.jl | 6 +- .../test/disco_benchmarks.jl | 83 +++++++++++++++++++ lib/OrdinaryDiffEqCore/test/disco_tests.jl | 2 +- 4 files changed, 86 insertions(+), 8 deletions(-) create mode 100644 lib/OrdinaryDiffEqCore/test/disco_benchmarks.jl diff --git a/lib/OrdinaryDiffEqBDF/src/controllers.jl b/lib/OrdinaryDiffEqBDF/src/controllers.jl index 0c0a3b30879..78604eae0b5 100644 --- a/lib/OrdinaryDiffEqBDF/src/controllers.jl +++ b/lib/OrdinaryDiffEqBDF/src/controllers.jl @@ -104,7 +104,6 @@ function bdf_step_reject_controller!(integrator, cache, EEst1) return integrator.dt end - if integrator.cache.consfailcnt > 1 cache.consfailcnt += 1 cache.nconsteps = 0 if cache.consfailcnt > 1 @@ -505,4 +504,4 @@ function step_accept_controller!( cache.qwait -= 1 # countdown end return integrator.dt / q -end +end diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index e2352a257a8..d184dad1d2f 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -694,11 +694,7 @@ Base.@constprop :aggressive function _ode_init( sol, u, du, k, t, tType(_dt), f, p, uprev, uprev2, duprev, tprev, _alg, dtcache, dtchangeable, - dtpropose, disco_dt_set, tdir, eigen_est, EEst, - # TODO vvv remove these - QT(qoldinit), q11, - erracc, dtacc, - # TODO ^^^ remove these + dtpropose, disco_dt_set, tdir, eigen_est, controller_cache, success_iter, iter, saveiter, saveiter_dense, cache, diff --git a/lib/OrdinaryDiffEqCore/test/disco_benchmarks.jl b/lib/OrdinaryDiffEqCore/test/disco_benchmarks.jl new file mode 100644 index 00000000000..c762710ec29 --- /dev/null +++ b/lib/OrdinaryDiffEqCore/test/disco_benchmarks.jl @@ -0,0 +1,83 @@ +using DiffEqDevTools, Test, LinearAlgebra +using OrdinaryDiffEqTsit5, OrdinaryDiffEqRosenbrock, OrdinaryDiffEqLowOrderRK +using OrdinaryDiffEqRadau, OrdinaryDiffEqBS3 +using Logging +global_logger(ConsoleLogger(stderr, Logging.Error)) +using BenchmarkTools + + +#tests against Hairer's RADAR problems +h(p, t) = 0.5 + +# state-dependent delay: τ(t) = y(t) +function delay(p, t, u) + return u[1] +end + +# DDE: y'(t) = y(y(t)) +function f(du, u, h, p, t) + τ = u[1] + du[1] = h(p, τ) +end + +# initial condition at t = 0 (must match tspan start) +u0 = [1.0] +tspan = (1.0, 5.5) + +prob = DDEProblem(f, h, delay, u0, tspan) + +sol = solve(prob, MethodOfSteps(Tsit5())) + + +# https://dieci.math.gatech.edu/preps/DieciLopez-Fili4.pdf +# vector fields +function f1!(du, u, p, t) + x1, x2 = u + du[1] = x2 + du[2] = -x1 + 1/(1.2 - x2) +end + +function f2!(du, u, p, t) + x1, x2 = u + du[1] = x2 + du[2] = -x1 - 1/(0.8 + x2) +end + +# switching surface Σ: x2 = 0.2 +condition(u, p, t) = u[2] - 0.2 + +# mode indicator (which vector field is active) +mode = Ref(1) + +function f!(du, u, p, t) + if mode[] == 1 + f1!(du, u, p, t) + else + f2!(du, u, p, t) + end +end + +# switch dynamics when crossing Σ +function affect!(integrator) + mode[] = 2 # toggle 1 ↔ 2 +end + +cb = ContinuousCallback(condition, affect!, is_discontinuity = true;) +cb2 = ContinuousCallback(condition, affect!, is_discontinuity = false;) + +u0 = [-0.4, -0.5] +tspan = (0.0, 10.0) + +prob = ODEProblem(f!, u0, tspan) + +sol_disco_tsit5 = solve(prob, Tsit5(), callback=cb) +sol_no_disco_tsit5 = solve(prob, Tsit5(), callback=cb2) + +sol_disco_radau = solve(prob, RadauIIA5(), callback=cb) +sol_no_disco_radau = solve(prob, RadauIIA5(), callback=cb2) + +sol_disco_rosenbrock = solve(prob, Rodas5P(), callback=cb) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(), callback=cb2) + +sol_disco_bs3 = solve(prob, BS3(), callback=cb) +sol_no_disco_bs3 = solve(prob, BS3(), callback=cb2) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 17b3f49dc72..30641aabe38 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -35,7 +35,7 @@ sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) # 46.500 μs (7129 allocations: 226.22 KiB) -sol_disco_BS5= solve(prob, BS5(); callback = cb, reltol = 1e-6) +sol_disco_BS5 = solve(prob, BS5(); callback = cb, reltol = 1e-6) sol_no_disco_BS5 = solve(prob, BS5(); callback = cb2, reltol = 1e-6) #TEST 2: TWO DISCONTINUITIES From 8ac29b1e08fd017973fb2613851b3b9a550806c1 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Tue, 5 May 2026 17:23:57 -0400 Subject: [PATCH 18/27] controller fix --- lib/OrdinaryDiffEqCore/src/integrators/controllers.jl | 5 +++++ lib/OrdinaryDiffEqCore/test/disco_tests.jl | 10 +++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index 45835dbdf96..687ce6e5772 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -113,6 +113,11 @@ end end @inline function step_reject_controller!(integrator, alg) + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end step_reject_controller!(integrator, integrator.controller_cache, alg) cache = integrator.cache if hasfield(typeof(cache), :nlsolve) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 30641aabe38..9e07fde81a6 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -21,14 +21,14 @@ cb = ContinuousCallback(condition, affect!; is_discontinuity = true) cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) sol_disco_radau = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 298.084 μs (8108 allocations: 257.11 KiB) +# 294.458 μs (8082 allocations: 256.59 KiB) sol_no_disco_radau = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 356.708 μs (10024 allocations: 312.08 KiB) sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) -# 418.584 μs (16472 allocations: 576.75 KiB) +# 474.375 μs (16801 allocations: 592.50 KiB) sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) -# 440.375 μs (17875 allocations: 622.09 KiB) +# 509.083 μs (18240 allocations: 639.33 KiB) sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) # 59.542 μs (7248 allocations: 233.67 KiB) @@ -72,9 +72,9 @@ cb = CallbackSet(cb1, cb2) cb2 = CallbackSet(cb1f, cb2f) #disco solve -sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) +sol_disco_radau = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) # 1.503 ms (41672 allocations: 1.27 MiB) -sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) +sol_no_disco_radau = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 1.306 ms (37092 allocations: 1.13 MiB) sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) From 5b1069d0594c6814683356aaa889a62284ad31cd Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Tue, 5 May 2026 18:04:58 -0400 Subject: [PATCH 19/27] fix bug --- lib/OrdinaryDiffEqCore/src/integrators/controllers.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index 687ce6e5772..e6178c64995 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -325,7 +325,6 @@ function step_accept_controller!(integrator, cache::IControllerCache, alg, q) end function step_reject_controller!(integrator, cache::IControllerCache, alg) - @assert cache.dtreject ≈ integrator.qold "Controller cache went out of sync with time stepping logic." disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) if disco_dt != -1 integrator.dt = disco_dt From e9d4d3c29afb49a3faee60c1a9c068f95a0900a2 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Wed, 6 May 2026 14:47:12 -0400 Subject: [PATCH 20/27] controller edits, still WIP --- lib/OrdinaryDiffEqCore/src/disco.jl | 2 +- .../src/integrators/controllers.jl | 85 +++++++++++++------ .../src/integrators/type.jl | 1 - lib/OrdinaryDiffEqCore/src/solve.jl | 7 +- .../test/disco_benchmarks.jl | 30 +++---- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 36 ++++---- 6 files changed, 95 insertions(+), 66 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index cf20b301292..3c2190c3927 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -26,7 +26,7 @@ function find_discontinuity(u, uprev, integrator, cache) breakpointθ = -one(dt) idx = 1 for i in cb.continuous_callbacks - if (!(i.is_discontinuity)) + if (!(i.maybe_discontinuity)) continue end disco_prob = integrator.disco_probs[idx] diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index e6178c64995..de5d0e00e59 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -113,10 +113,13 @@ end end @inline function step_reject_controller!(integrator, alg) - disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) - if disco_dt != -1 - integrator.dt = disco_dt - return integrator.dt + disco_handling = integrator.controller_cache.controller.disco_handling + if disco_handling + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end end step_reject_controller!(integrator, integrator.controller_cache, alg) cache = integrator.cache @@ -194,10 +197,12 @@ on the algorithm cache continues to work. mutable struct DummyControllerCache{T, C} <: AbstractControllerCache EEst::T cache::C + disco_handling::Bool end function setup_controller_cache(alg, cache, controller::DummyController, ::Type{E}) where {E} - return DummyControllerCache{E, typeof(cache)}(oneunit(E), cache) + disco_handling = false + return DummyControllerCache{E, typeof(cache)}(oneunit(E), cache, disco_handling) end # Algorithms with integrated controllers (BDF, Nordsieck, …) only define their @@ -255,9 +260,11 @@ struct IController{T} <: AbstractController gamma::T qsteady_min::T qsteady_max::T + disco_handling::Bool end function IController(; qmin = 1 // 5, qmax = 10 // 1, qmax_first_step = 10000 // 1, gamma = 9 // 10, qsteady_min = 1 // 1, qsteady_max = 6 // 5) + disco_handling = false return IController{typeof(qmin)}( # FIXME combined promoted type qmin, qmax, @@ -265,6 +272,7 @@ function IController(; qmin = 1 // 5, qmax = 10 // 1, qmax_first_step = 10000 // gamma, qsteady_min, qsteady_max, + disco_handling ) end @@ -272,7 +280,7 @@ function IController(alg; kwargs...) return IController(Float64, alg; kwargs...) end -function IController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing) +function IController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, disco_handling = nothing) return IController{QT}( qmin === nothing ? qmin_default(alg) : qmin, qmax === nothing ? qmax_default(alg) : qmax, @@ -280,6 +288,7 @@ function IController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = gamma === nothing ? gamma_default(alg) : gamma, qsteady_min === nothing ? qsteady_min_default(alg) : qsteady_min, qsteady_max === nothing ? qsteady_max_default(alg) : qsteady_max, + disco_handling === nothing ? false : disco_handling ) end @@ -325,10 +334,13 @@ function step_accept_controller!(integrator, cache::IControllerCache, alg, q) end function step_reject_controller!(integrator, cache::IControllerCache, alg) - disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) - if disco_dt != -1 - integrator.dt = disco_dt - return integrator.dt + disco_handling = integrator.controller_cache.controller.disco_handling + if disco_handling + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end end return integrator.dt = cache.dtreject # TODO this does not look right. end @@ -387,9 +399,11 @@ mutable struct PIController{T} <: AbstractController # TODO remove the mutable o qsteady_min::T qsteady_max::T qoldinit::T + disco_handling::Bool end function PIController(beta1::Real, beta2::Real; qmin = 1 // 5, qmax = 10 // 0, qmax_first_step = 10000 // 1, gamma = 9 // 10, qsteady_min = 1 // 1, qsteady_max = 6 // 5, qoldinit = 1 // 10^4) + disco_handling = false return PIController{typeof(beta1)}( beta1, beta2, @@ -400,6 +414,7 @@ function PIController(beta1::Real, beta2::Real; qmin = 1 // 5, qmax = 10 // 0, q qsteady_min, qsteady_max, qoldinit, + disco_handling ) end @@ -407,7 +422,7 @@ function PIController(alg; kwargs...) return PIController(Float64, alg; kwargs...) end -function PIController(QT, alg; beta1 = nothing, beta2 = nothing, qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, qoldinit = nothing) +function PIController(QT, alg; beta1 = nothing, beta2 = nothing, qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, qoldinit = nothing, disco_handling = nothing) beta2 = beta2 === nothing ? beta2_default(alg) : beta2 beta1 = beta1 === nothing ? beta1_default(alg, beta2) : beta1 qoldinit = qoldinit === nothing ? 1 // 10^4 : qoldinit @@ -420,7 +435,8 @@ function PIController(QT, alg; beta1 = nothing, beta2 = nothing, qmin = nothing, gamma === nothing ? gamma_default(alg) : gamma, qsteady_min === nothing ? qsteady_min_default(alg) : qsteady_min, qsteady_max === nothing ? qsteady_max_default(alg) : qsteady_max, - qoldinit, + qoldinit, + disco_handling === nothing ? false : disco_handling ) end @@ -475,10 +491,15 @@ end function step_reject_controller!(integrator, cache::PIControllerCache, alg) (; controller, q11) = cache (; qmin, gamma) = controller - disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) - if disco_dt != -1 - integrator.dt = disco_dt - return integrator.dt + disco_handling = integrator.controller_cache.controller.disco_handling + #tsit comes here + if disco_handling + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + println("using disco set dt") + integrator.dt = disco_dt + return integrator.dt + end end return integrator.dt /= min(inv(qmin), q11 / gamma) end @@ -567,18 +588,21 @@ struct PIDController{QT, Limiter} <: AbstractController limiter::Limiter # limiter of the dt factor (before clipping) qsteady_min::QT qsteady_max::QT + disco_handling::Bool end @inline default_dt_factor_limiter(x) = one(x) + atan(x - one(x)) function PIDController(beta1::Real, beta2::Real, beta3::Real = zero(beta1); accept_safety = 0.81, limiter = default_dt_factor_limiter, qsteady_min = 1 // 1, qsteady_max = 6 // 5) beta = map(float, promote(beta1, beta2, beta3)) + disco_handling = false return PIDController{typeof(beta1), typeof(limiter)}( beta, accept_safety, limiter, qsteady_min, qsteady_max, + disco_handling ) end @@ -601,6 +625,7 @@ function PIDController(QT, alg; beta = nothing, accept_safety = 0.81, limiter = limiter, QT(qsteady_min === nothing ? qsteady_min_default(alg) : qsteady_min), QT(qsteady_max === nothing ? qsteady_max_default(alg) : qsteady_max), + disco_handling === nothing ? false : disco_handling ) end @@ -698,10 +723,13 @@ function step_accept_controller!(integrator, cache::PIDControllerCache, alg, dt_ end function step_reject_controller!(integrator, cache::PIDControllerCache, alg) - disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) - if disco_dt != -1 - integrator.dt = disco_dt - return integrator.dt + disco_handling = integrator.controller_cache.controller.disco_handling + if disco_handling + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end end return integrator.dt *= cache.dt_factor end @@ -774,9 +802,11 @@ struct PredictiveController{T} <: AbstractController gamma::T qsteady_min::T qsteady_max::T + disco_handling::Bool end function PredictiveController(; qmin = float(1 // 5), qmax = 10 // 1, qmax_first_step = 10000 // 1, gamma = 9 // 10, qsteady_min = 1 // 1, qsteady_max = 6 // 5) + disco_handling = false return PredictiveController{typeof(qmin)}( # FIXME combined promoted type qmin, qmax, @@ -784,6 +814,7 @@ function PredictiveController(; qmin = float(1 // 5), qmax = 10 // 1, qmax_first gamma, qsteady_min, qsteady_max, + disco_handling ) end @@ -791,7 +822,7 @@ function PredictiveController(alg; kwargs...) return PredictiveController(Float64, alg; kwargs...) end -function PredictiveController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing) +function PredictiveController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, disco_handling = nothing) return PredictiveController{QT}( qmin === nothing ? qmin_default(alg) : qmin, qmax === nothing ? qmax_default(alg) : qmax, @@ -799,6 +830,7 @@ function PredictiveController(QT, alg; qmin = nothing, qmax = nothing, qmax_firs gamma === nothing ? gamma_default(alg) : gamma, qsteady_min === nothing ? qsteady_min_default(alg) : qsteady_min, qsteady_max === nothing ? qsteady_max_default(alg) : qsteady_max, + disco_handling === nothing ? false : disco_handling ) end @@ -888,10 +920,13 @@ end function step_reject_controller!(integrator, cache::PredictiveControllerCache, alg) (; dt, success_iter) = integrator (; qold) = cache - if (integrator.disco_dt_set) - println("using fixed dt from discontinuity handling") - integrator.disco_dt_set = false - return integrator.dt + disco_handling = integrator.controller_cache.controller.disco_handling + if disco_handling + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end end return integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/type.jl b/lib/OrdinaryDiffEqCore/src/integrators/type.jl index 1aa548feffa..2c2547ce1d5 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/type.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/type.jl @@ -143,7 +143,6 @@ mutable struct ODEIntegrator{ dtcache::tType dtchangeable::Bool dtpropose::tType - disco_dt_set::Bool tdir::tdirType eigen_est::eigenType controller_cache::CC diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index d184dad1d2f..d9fa6dc4808 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -87,7 +87,6 @@ Base.@constprop :aggressive function _ode_init( save_everystep = isempty(saveat), save_on = true, save_discretes = true, - disco_dt_set = false, save_start = save_everystep || isempty(saveat) || saveat isa Number || prob.tspan[1] in saveat, save_end = nothing, @@ -660,14 +659,14 @@ Base.@constprop :aggressive function _ode_init( num_probs = 0 for i in callbacks_internal.continuous_callbacks - if i.is_discontinuity + if i.maybe_discontinuity num_probs += 1 end end disco_probs = Vector{IntervalNonlinearProblem}(undef, num_probs) idx = 1 for i in callbacks_internal.continuous_callbacks - if i.is_discontinuity + if i.maybe_discontinuity u₁ = similar(u) out = i isa VectorContinuousCallback ? similar(u) : nothing zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars, 1, out, f, tprev, p) @@ -694,7 +693,7 @@ Base.@constprop :aggressive function _ode_init( sol, u, du, k, t, tType(_dt), f, p, uprev, uprev2, duprev, tprev, _alg, dtcache, dtchangeable, - dtpropose, disco_dt_set, tdir, eigen_est, + dtpropose, tdir, eigen_est, controller_cache, success_iter, iter, saveiter, saveiter_dense, cache, diff --git a/lib/OrdinaryDiffEqCore/test/disco_benchmarks.jl b/lib/OrdinaryDiffEqCore/test/disco_benchmarks.jl index c762710ec29..61d9bf0ded3 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_benchmarks.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_benchmarks.jl @@ -1,36 +1,32 @@ -using DiffEqDevTools, Test, LinearAlgebra -using OrdinaryDiffEqTsit5, OrdinaryDiffEqRosenbrock, OrdinaryDiffEqLowOrderRK -using OrdinaryDiffEqRadau, OrdinaryDiffEqBS3 +using DiffEqDevTools, Test, LinearAlgebra, DelayDiffEq +using OrdinaryDiffEqTsit5, OrdinaryDiffEqRosenbrock, OrdinaryDiffEqLowOrderRK, OrdinaryDiffEqFIRK, OrdinaryDiffEqBDF using Logging global_logger(ConsoleLogger(stderr, Logging.Error)) using BenchmarkTools - #tests against Hairer's RADAR problems -h(p, t) = 0.5 +h(p, t; idxs = nothing) = 0.5 # state-dependent delay: τ(t) = y(t) -function delay(p, t, u) +function delay(u, p, t) return u[1] end -# DDE: y'(t) = y(y(t)) function f(du, u, h, p, t) - τ = u[1] - du[1] = h(p, τ) + τ = delay(u, p, t) + du[1] = h(p, t - τ; idxs = 1) end -# initial condition at t = 0 (must match tspan start) +# initial condition at t = 0 u0 = [1.0] -tspan = (1.0, 5.5) - -prob = DDEProblem(f, h, delay, u0, tspan) - +tspan = (0.0, 10.0) +p = nothing +prob = DDEProblem(f, u0, h, tspan, p; dependent_lags = (delay,)) sol = solve(prob, MethodOfSteps(Tsit5())) - -# https://dieci.math.gatech.edu/preps/DieciLopez-Fili4.pdf +# https://dieci.math.gatech.edu/preps/DieciLopez-Fili4.pdf # vector fields +# BELOW IS WRONG NEED TO FIX function f1!(du, u, p, t) x1, x2 = u du[1] = x2 @@ -68,7 +64,7 @@ cb2 = ContinuousCallback(condition, affect!, is_discontinuity = false;) u0 = [-0.4, -0.5] tspan = (0.0, 10.0) -prob = ODEProblem(f!, u0, tspan) +prob = ODEProblem(f!, u0, tspan, p) sol_disco_tsit5 = solve(prob, Tsit5(), callback=cb) sol_no_disco_tsit5 = solve(prob, Tsit5(), callback=cb2) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 9e07fde81a6..e9c9196f8e3 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -17,8 +17,8 @@ function affect!(integrator) #println("fired callback at t=$(integrator.t), u=$(integrator.u[1])") integrator.u[1] += 10 end -cb = ContinuousCallback(condition, affect!; is_discontinuity = true) -cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) +cb = ContinuousCallback(condition, affect!; maybe_discontinuity = true) +cb2 = ContinuousCallback(condition, affect!; maybe_discontinuity = false) sol_disco_radau = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) # 294.458 μs (8082 allocations: 256.59 KiB) @@ -59,15 +59,15 @@ condition1(u, t, integrator) = u[1] - 1 function affect1!(integrator) #println("Callback 1 fired at t=$(integrator.t), u=$(integrator.u[1])") end -cb1 = ContinuousCallback(condition1, affect1!; is_discontinuity = true) -cb1f = ContinuousCallback(condition1, affect1!; is_discontinuity = false) +cb1 = ContinuousCallback(condition1, affect1!; maybe_discontinuity = true) +cb1f = ContinuousCallback(condition1, affect1!; maybe_discontinuity = false) condition2(u, t, integrator) = u[1] - 2 function affect2!(integrator) #println("Callback 2 fired at t=$(integrator.t), u=$(integrator.u[1])") end -cb2 = ContinuousCallback(condition2, affect2!; is_discontinuity = true) -cb2f = ContinuousCallback(condition2, affect2!; is_discontinuity = false) +cb2 = ContinuousCallback(condition2, affect2!; maybe_discontinuity = true) +cb2f = ContinuousCallback(condition2, affect2!; maybe_discontinuity = false) cb = CallbackSet(cb1, cb2) cb2 = CallbackSet(cb1f, cb2f) @@ -109,15 +109,15 @@ cond_multi_1(u, t, integrator) = u[1] - 0.3 function affect_multi_1!(integrator) #println("Multi-exponential discontinuity 1 callback fired at t=$(integrator.t), u=$(integrator.u[1])") end -cb_multi_1 = ContinuousCallback(cond_multi_1, affect_multi_1!; is_discontinuity = true) -cb_multi_1f = ContinuousCallback(cond_multi_1, affect_multi_1!; is_discontinuity = false) +cb_multi_1 = ContinuousCallback(cond_multi_1, affect_multi_1!; maybe_discontinuity = true) +cb_multi_1f = ContinuousCallback(cond_multi_1, affect_multi_1!; maybe_discontinuity = false) cond_multi_2(u, t, integrator) = u[1] - 0.8 function affect_multi_2!(integrator) #println("Multi-exponential discontinuity 2 callback fired at t=$(integrator.t), u=$(integrator.u[1])") end -cb_multi_2 = ContinuousCallback(cond_multi_2, affect_multi_2!; is_discontinuity = true) -cb_multi_2f = ContinuousCallback(cond_multi_2, affect_multi_2!; is_discontinuity = false) +cb_multi_2 = ContinuousCallback(cond_multi_2, affect_multi_2!; maybe_discontinuity = true) +cb_multi_2f = ContinuousCallback(cond_multi_2, affect_multi_2!; maybe_discontinuity = false) cb_multi = CallbackSet(cb_multi_1, cb_multi_2) cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) @@ -164,8 +164,8 @@ cond_stiff(u, t, integrator) = u[1] - 0.5 function affect_stiff!(integrator) #println("Stiff discontinuity callback fired at t=$(integrator.t), u=$(integrator.u[1])") end -cb_stiff = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = true) -cb_stiff_f = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = false) +cb_stiff = ContinuousCallback(cond_stiff, affect_stiff!; maybe_discontinuity = true) +cb_stiff_f = ContinuousCallback(cond_stiff, affect_stiff!; maybe_discontinuity = false) #disco solve sol_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) @@ -215,8 +215,8 @@ cond_dae(u, t, integrator) = u[1] - 0.5 function affect_dae!(integrator) #println("DAE discontinuity callback fired at t=$(integrator.t), u=$(integrator.u)") end -cb_dae = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = true) -cb_daef = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = false) +cb_dae = ContinuousCallback(cond_dae, affect_dae!; maybe_discontinuity = true) +cb_daef = ContinuousCallback(cond_dae, affect_dae!; maybe_discontinuity = false) radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol=1e-10) # 88.542 μs (870 allocations: 41.86 KiB) @@ -255,8 +255,8 @@ function affect!(integrator, idx) end end -cb = VectorContinuousCallback(condition!, affect!, 2; is_discontinuity = true) -cb2 = VectorContinuousCallback(condition!, affect!, 2; is_discontinuity = false) +cb = VectorContinuousCallback(condition!, affect!, 2; maybe_discontinuity = true) +cb2 = VectorContinuousCallback(condition!, affect!, 2; maybe_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb) # 49.125 μs (664 allocations: 32.89 KiB) @@ -291,8 +291,8 @@ prob = ODEProblem(f!, u, tspan) cond(u, t, integrator) = u[2] affect!(integrator) = nothing -cb = ContinuousCallback(cond, affect!; is_discontinuity = true) -cb2 = ContinuousCallback(cond, affect!; is_discontinuity = false) +cb = ContinuousCallback(cond, affect!; maybe_discontinuity = true) +cb2 = ContinuousCallback(cond, affect!; maybe_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-8, abstol = 1e-10) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) From fc9a970ff7aa07dee3c824e35cd76c19d1713c14 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Fri, 8 May 2026 11:17:19 -0400 Subject: [PATCH 21/27] docstring and renaming --- .../src/integrators/controllers.jl | 87 +++++++++++-------- 1 file changed, 53 insertions(+), 34 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index de5d0e00e59..09300d02042 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -113,8 +113,8 @@ end end @inline function step_reject_controller!(integrator, alg) - disco_handling = integrator.controller_cache.controller.disco_handling - if disco_handling + discontinuity_detection = integrator.controller_cache.controller.discontinuity_detection + if discontinuity_detection disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) if disco_dt != -1 integrator.dt = disco_dt @@ -193,16 +193,20 @@ Controller cache used by algorithms that manage step-size selection themselves (BDF, Nordsieck, Leaping, …). Holds the scalar error estimate exposed through `get_EEst(integrator)` and a reference to the algorithm cache so existing dispatch on the algorithm cache continues to work. +If `discontinuity_detection` is set to true, the algorithm will run the autonomous +discontinuity detection to predict the best next timestep after step rejection. +Otherwise, it follows the default step rejection algorithm. This feature is currently +defaulted off. """ mutable struct DummyControllerCache{T, C} <: AbstractControllerCache EEst::T cache::C - disco_handling::Bool + discontinuity_detection::Bool end function setup_controller_cache(alg, cache, controller::DummyController, ::Type{E}) where {E} - disco_handling = false - return DummyControllerCache{E, typeof(cache)}(oneunit(E), cache, disco_handling) + discontinuity_detection = false + return DummyControllerCache{E, typeof(cache)}(oneunit(E), cache, discontinuity_detection) end # Algorithms with integrated controllers (BDF, Nordsieck, …) only define their @@ -246,6 +250,10 @@ the interval `[qmin, qmax]`. A step will be accepted whenever the estimated error `get_EEst(integrator)` is less than or equal to unity. Otherwise, the step is rejected and re-tried with the predicted step size. +If `discontinuity_detection` is set to true, the algorithm will run the autonomous +discontinuity detection to predict the best next timestep after step rejection. +Otherwise, it follows the default step rejection algorithm. This feature is currently +defaulted off. ## References @@ -260,11 +268,11 @@ struct IController{T} <: AbstractController gamma::T qsteady_min::T qsteady_max::T - disco_handling::Bool + discontinuity_detection::Bool end function IController(; qmin = 1 // 5, qmax = 10 // 1, qmax_first_step = 10000 // 1, gamma = 9 // 10, qsteady_min = 1 // 1, qsteady_max = 6 // 5) - disco_handling = false + discontinuity_detection = false return IController{typeof(qmin)}( # FIXME combined promoted type qmin, qmax, @@ -272,7 +280,7 @@ function IController(; qmin = 1 // 5, qmax = 10 // 1, qmax_first_step = 10000 // gamma, qsteady_min, qsteady_max, - disco_handling + discontinuity_detection ) end @@ -280,7 +288,7 @@ function IController(alg; kwargs...) return IController(Float64, alg; kwargs...) end -function IController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, disco_handling = nothing) +function IController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, discontinuity_detection = nothing) return IController{QT}( qmin === nothing ? qmin_default(alg) : qmin, qmax === nothing ? qmax_default(alg) : qmax, @@ -288,7 +296,7 @@ function IController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = gamma === nothing ? gamma_default(alg) : gamma, qsteady_min === nothing ? qsteady_min_default(alg) : qsteady_min, qsteady_max === nothing ? qsteady_max_default(alg) : qsteady_max, - disco_handling === nothing ? false : disco_handling + discontinuity_detection === nothing ? false : discontinuity_detection ) end @@ -334,8 +342,8 @@ function step_accept_controller!(integrator, cache::IControllerCache, alg, q) end function step_reject_controller!(integrator, cache::IControllerCache, alg) - disco_handling = integrator.controller_cache.controller.disco_handling - if disco_handling + discontinuity_detection = integrator.controller_cache.controller.discontinuity_detection + if discontinuity_detection disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) if disco_dt != -1 integrator.dt = disco_dt @@ -373,7 +381,10 @@ the interval `[qmin, qmax]`. A step will be accepted whenever the estimated error `get_EEst(integrator)` is less than or equal to unity. Otherwise, the step is rejected and re-tried with the predicted step size. - +If `discontinuity_detection` is set to true, the algorithm will run the autonomous +discontinuity detection to predict the best next timestep after step rejection. +Otherwise, it follows the default step rejection algorithm. This feature is currently +defaulted off. !!! note The coefficients `beta1, beta2` are not scaled by the order of the method, @@ -399,11 +410,11 @@ mutable struct PIController{T} <: AbstractController # TODO remove the mutable o qsteady_min::T qsteady_max::T qoldinit::T - disco_handling::Bool + discontinuity_detection::Bool end function PIController(beta1::Real, beta2::Real; qmin = 1 // 5, qmax = 10 // 0, qmax_first_step = 10000 // 1, gamma = 9 // 10, qsteady_min = 1 // 1, qsteady_max = 6 // 5, qoldinit = 1 // 10^4) - disco_handling = false + discontinuity_detection = false return PIController{typeof(beta1)}( beta1, beta2, @@ -414,7 +425,7 @@ function PIController(beta1::Real, beta2::Real; qmin = 1 // 5, qmax = 10 // 0, q qsteady_min, qsteady_max, qoldinit, - disco_handling + discontinuity_detection ) end @@ -422,7 +433,7 @@ function PIController(alg; kwargs...) return PIController(Float64, alg; kwargs...) end -function PIController(QT, alg; beta1 = nothing, beta2 = nothing, qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, qoldinit = nothing, disco_handling = nothing) +function PIController(QT, alg; beta1 = nothing, beta2 = nothing, qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, qoldinit = nothing, discontinuity_detection = nothing) beta2 = beta2 === nothing ? beta2_default(alg) : beta2 beta1 = beta1 === nothing ? beta1_default(alg, beta2) : beta1 qoldinit = qoldinit === nothing ? 1 // 10^4 : qoldinit @@ -436,7 +447,7 @@ function PIController(QT, alg; beta1 = nothing, beta2 = nothing, qmin = nothing, qsteady_min === nothing ? qsteady_min_default(alg) : qsteady_min, qsteady_max === nothing ? qsteady_max_default(alg) : qsteady_max, qoldinit, - disco_handling === nothing ? false : disco_handling + discontinuity_detection === nothing ? false : discontinuity_detection ) end @@ -491,9 +502,9 @@ end function step_reject_controller!(integrator, cache::PIControllerCache, alg) (; controller, q11) = cache (; qmin, gamma) = controller - disco_handling = integrator.controller_cache.controller.disco_handling + discontinuity_detection = integrator.controller_cache.controller.discontinuity_detection #tsit comes here - if disco_handling + if discontinuity_detection disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) if disco_dt != -1 println("using disco set dt") @@ -553,6 +564,11 @@ Some standard controller parameters suggested in the literature are | H211PI | `1//6` | `1//6` | `0` | | H312PID | `1//18` | `1//9` | `1//18` | +If `discontinuity_detection` is set to true, the algorithm will run the autonomous +discontinuity detection to predict the best next timestep after step rejection. +Otherwise, it follows the default step rejection algorithm. This feature is currently +defaulted off. + !!! note In contrast to the [`PIController`](@ref), the coefficients `beta1, beta2, beta3` @@ -588,21 +604,21 @@ struct PIDController{QT, Limiter} <: AbstractController limiter::Limiter # limiter of the dt factor (before clipping) qsteady_min::QT qsteady_max::QT - disco_handling::Bool + discontinuity_detection::Bool end @inline default_dt_factor_limiter(x) = one(x) + atan(x - one(x)) function PIDController(beta1::Real, beta2::Real, beta3::Real = zero(beta1); accept_safety = 0.81, limiter = default_dt_factor_limiter, qsteady_min = 1 // 1, qsteady_max = 6 // 5) beta = map(float, promote(beta1, beta2, beta3)) - disco_handling = false + discontinuity_detection = false return PIDController{typeof(beta1), typeof(limiter)}( beta, accept_safety, limiter, qsteady_min, qsteady_max, - disco_handling + discontinuity_detection ) end @@ -625,7 +641,7 @@ function PIDController(QT, alg; beta = nothing, accept_safety = 0.81, limiter = limiter, QT(qsteady_min === nothing ? qsteady_min_default(alg) : qsteady_min), QT(qsteady_max === nothing ? qsteady_max_default(alg) : qsteady_max), - disco_handling === nothing ? false : disco_handling + discontinuity_detection === nothing ? false : discontinuity_detection ) end @@ -723,8 +739,8 @@ function step_accept_controller!(integrator, cache::PIDControllerCache, alg, dt_ end function step_reject_controller!(integrator, cache::PIDControllerCache, alg) - disco_handling = integrator.controller_cache.controller.disco_handling - if disco_handling + discontinuity_detection = integrator.controller_cache.controller.discontinuity_detection + if discontinuity_detection disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) if disco_dt != -1 integrator.dt = disco_dt @@ -786,7 +802,10 @@ integrator.dt / qacc ``` When it rejects, it's the same as the [`IController`](@ref): - +If `discontinuity_detection` is set to true, the algorithm will run the autonomous +discontinuity detection to predict the best next timestep after step rejection. +Otherwise, it follows the default step rejection algorithm. This feature is currently +defaulted off. ```julia if integrator.success_iter == 0 integrator.dt *= 0.1 @@ -802,11 +821,11 @@ struct PredictiveController{T} <: AbstractController gamma::T qsteady_min::T qsteady_max::T - disco_handling::Bool + discontinuity_detection::Bool end function PredictiveController(; qmin = float(1 // 5), qmax = 10 // 1, qmax_first_step = 10000 // 1, gamma = 9 // 10, qsteady_min = 1 // 1, qsteady_max = 6 // 5) - disco_handling = false + discontinuity_detection = false return PredictiveController{typeof(qmin)}( # FIXME combined promoted type qmin, qmax, @@ -814,7 +833,7 @@ function PredictiveController(; qmin = float(1 // 5), qmax = 10 // 1, qmax_first gamma, qsteady_min, qsteady_max, - disco_handling + discontinuity_detection ) end @@ -822,7 +841,7 @@ function PredictiveController(alg; kwargs...) return PredictiveController(Float64, alg; kwargs...) end -function PredictiveController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, disco_handling = nothing) +function PredictiveController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, discontinuity_detection = nothing) return PredictiveController{QT}( qmin === nothing ? qmin_default(alg) : qmin, qmax === nothing ? qmax_default(alg) : qmax, @@ -830,7 +849,7 @@ function PredictiveController(QT, alg; qmin = nothing, qmax = nothing, qmax_firs gamma === nothing ? gamma_default(alg) : gamma, qsteady_min === nothing ? qsteady_min_default(alg) : qsteady_min, qsteady_max === nothing ? qsteady_max_default(alg) : qsteady_max, - disco_handling === nothing ? false : disco_handling + discontinuity_detection === nothing ? false : discontinuity_detection ) end @@ -920,8 +939,8 @@ end function step_reject_controller!(integrator, cache::PredictiveControllerCache, alg) (; dt, success_iter) = integrator (; qold) = cache - disco_handling = integrator.controller_cache.controller.disco_handling - if disco_handling + discontinuity_detection = integrator.controller_cache.controller.discontinuity_detection + if discontinuity_detection disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) if disco_dt != -1 integrator.dt = disco_dt From b9584cfec040ba2585407fdceaebde5cfbe7d889 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Fri, 8 May 2026 12:31:29 -0400 Subject: [PATCH 22/27] some edits --- lib/OrdinaryDiffEqCore/src/integrators/controllers.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index 09300d02042..2ab95ef0e2d 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -113,14 +113,6 @@ end end @inline function step_reject_controller!(integrator, alg) - discontinuity_detection = integrator.controller_cache.controller.discontinuity_detection - if discontinuity_detection - disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) - if disco_dt != -1 - integrator.dt = disco_dt - return integrator.dt - end - end step_reject_controller!(integrator, integrator.controller_cache, alg) cache = integrator.cache if hasfield(typeof(cache), :nlsolve) @@ -503,7 +495,6 @@ function step_reject_controller!(integrator, cache::PIControllerCache, alg) (; controller, q11) = cache (; qmin, gamma) = controller discontinuity_detection = integrator.controller_cache.controller.discontinuity_detection - #tsit comes here if discontinuity_detection disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) if disco_dt != -1 From 300e59e3f96cfd611b84aaa62100d34e1851514d Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Fri, 8 May 2026 15:20:04 -0400 Subject: [PATCH 23/27] tests --- .../src/integrators/controllers.jl | 1 - .../test/disco_benchmarks.jl | 79 ------- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 200 ++++++++---------- 3 files changed, 92 insertions(+), 188 deletions(-) delete mode 100644 lib/OrdinaryDiffEqCore/test/disco_benchmarks.jl diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index 2ab95ef0e2d..41f9b206b60 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -498,7 +498,6 @@ function step_reject_controller!(integrator, cache::PIControllerCache, alg) if discontinuity_detection disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) if disco_dt != -1 - println("using disco set dt") integrator.dt = disco_dt return integrator.dt end diff --git a/lib/OrdinaryDiffEqCore/test/disco_benchmarks.jl b/lib/OrdinaryDiffEqCore/test/disco_benchmarks.jl deleted file mode 100644 index 61d9bf0ded3..00000000000 --- a/lib/OrdinaryDiffEqCore/test/disco_benchmarks.jl +++ /dev/null @@ -1,79 +0,0 @@ -using DiffEqDevTools, Test, LinearAlgebra, DelayDiffEq -using OrdinaryDiffEqTsit5, OrdinaryDiffEqRosenbrock, OrdinaryDiffEqLowOrderRK, OrdinaryDiffEqFIRK, OrdinaryDiffEqBDF -using Logging -global_logger(ConsoleLogger(stderr, Logging.Error)) -using BenchmarkTools - -#tests against Hairer's RADAR problems -h(p, t; idxs = nothing) = 0.5 - -# state-dependent delay: τ(t) = y(t) -function delay(u, p, t) - return u[1] -end - -function f(du, u, h, p, t) - τ = delay(u, p, t) - du[1] = h(p, t - τ; idxs = 1) -end - -# initial condition at t = 0 -u0 = [1.0] -tspan = (0.0, 10.0) -p = nothing -prob = DDEProblem(f, u0, h, tspan, p; dependent_lags = (delay,)) -sol = solve(prob, MethodOfSteps(Tsit5())) - -# https://dieci.math.gatech.edu/preps/DieciLopez-Fili4.pdf -# vector fields -# BELOW IS WRONG NEED TO FIX -function f1!(du, u, p, t) - x1, x2 = u - du[1] = x2 - du[2] = -x1 + 1/(1.2 - x2) -end - -function f2!(du, u, p, t) - x1, x2 = u - du[1] = x2 - du[2] = -x1 - 1/(0.8 + x2) -end - -# switching surface Σ: x2 = 0.2 -condition(u, p, t) = u[2] - 0.2 - -# mode indicator (which vector field is active) -mode = Ref(1) - -function f!(du, u, p, t) - if mode[] == 1 - f1!(du, u, p, t) - else - f2!(du, u, p, t) - end -end - -# switch dynamics when crossing Σ -function affect!(integrator) - mode[] = 2 # toggle 1 ↔ 2 -end - -cb = ContinuousCallback(condition, affect!, is_discontinuity = true;) -cb2 = ContinuousCallback(condition, affect!, is_discontinuity = false;) - -u0 = [-0.4, -0.5] -tspan = (0.0, 10.0) - -prob = ODEProblem(f!, u0, tspan, p) - -sol_disco_tsit5 = solve(prob, Tsit5(), callback=cb) -sol_no_disco_tsit5 = solve(prob, Tsit5(), callback=cb2) - -sol_disco_radau = solve(prob, RadauIIA5(), callback=cb) -sol_no_disco_radau = solve(prob, RadauIIA5(), callback=cb2) - -sol_disco_rosenbrock = solve(prob, Rodas5P(), callback=cb) -sol_no_disco_rosenbrock = solve(prob, Rodas5P(), callback=cb2) - -sol_disco_bs3 = solve(prob, BS3(), callback=cb) -sol_no_disco_bs3 = solve(prob, BS3(), callback=cb2) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index e9c9196f8e3..4f9e94b6e62 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,9 +1,16 @@ -using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra -using OrdinaryDiffEqTsit5, OrdinaryDiffEqRosenbrock, OrdinaryDiffEqLowOrderRK +using OrdinaryDiffEqCore, DiffEqDevTools, Test, LinearAlgebra +using OrdinaryDiffEqTsit5, OrdinaryDiffEqRosenbrock, OrdinaryDiffEqLowOrderRK, OrdinaryDiffEqFIRK using Logging global_logger(ConsoleLogger(stderr, Logging.Error)) using BenchmarkTools +predictive_disco_controller(alg) = OrdinaryDiffEqCore.PredictiveController(alg; discontinuity_detection = true) +PI_disco_controller(alg) = OrdinaryDiffEqCore.PIController(alg; discontinuity_detection = true) + +function default_affect!(integrator) + nothing +end + #TEST 1: SIMPLE DISCONTINUITY #test example discontinuous at u = 1 f(u, p, t) = u[1] < 1 ? [2u[1]] : [-3u[1] + 5] @@ -13,38 +20,37 @@ prob = ODEProblem(f, u0, tspan) #define callback condition(u, t, integrator) = u[1] - 1 -function affect!(integrator) - #println("fired callback at t=$(integrator.t), u=$(integrator.u[1])") - integrator.u[1] += 10 -end -cb = ContinuousCallback(condition, affect!; maybe_discontinuity = true) -cb2 = ContinuousCallback(condition, affect!; maybe_discontinuity = false) +cb = ContinuousCallback(condition, default_affect!; maybe_discontinuity = true) +cb2 = ContinuousCallback(condition, default_affect!; maybe_discontinuity = false) -sol_disco_radau = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) +sol_disco_radau = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6, controller = predictive_disco_controller(RadauIIA5())) # 294.458 μs (8082 allocations: 256.59 KiB) sol_no_disco_radau = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 356.708 μs (10024 allocations: 312.08 KiB) +@test sol_disco_radau.retcode == ReturnCode.Success +@test sol_disco_radau.stats.nreject <= sol_no_disco_radau.stats.nreject -sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6, controller = PI_disco_controller(Rodas5P())) # 474.375 μs (16801 allocations: 592.50 KiB) sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) # 509.083 μs (18240 allocations: 639.33 KiB) +@test sol_disco_rosenbrock.retcode == ReturnCode.Success +@test sol_disco_rosenbrock.stats.nreject <= sol_no_disco_rosenbrock.stats.nreject -sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6, controller = PI_disco_controller(Tsit5())) # 59.542 μs (7248 allocations: 233.67 KiB) sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) # 46.500 μs (7129 allocations: 226.22 KiB) - -sol_disco_BS5 = solve(prob, BS5(); callback = cb, reltol = 1e-6) -sol_no_disco_BS5 = solve(prob, BS5(); callback = cb2, reltol = 1e-6) +@test sol_disco_tsit5.retcode == ReturnCode.Success +@test sol_disco_tsit5.stats.nreject <= sol_no_disco_tsit5.stats.nreject #TEST 2: TWO DISCONTINUITIES #two discontinuity functions function f(u, p, t) if u[1] < 1 - [2u[1]] # region 1: grows to hit u = 1 + [2u[1]] elseif u[1] < 2 - [u[1] + 0.2] # region 2: continues increasing to hit u = 2 + [u[1] + 0.2] else [-4u[1] + 12] end @@ -56,47 +62,38 @@ prob = ODEProblem(f, u0, tspan) #define callbacks condition1(u, t, integrator) = u[1] - 1 -function affect1!(integrator) - #println("Callback 1 fired at t=$(integrator.t), u=$(integrator.u[1])") -end -cb1 = ContinuousCallback(condition1, affect1!; maybe_discontinuity = true) -cb1f = ContinuousCallback(condition1, affect1!; maybe_discontinuity = false) +cb1 = ContinuousCallback(condition1, default_affect!; maybe_discontinuity = true) +cb1f = ContinuousCallback(condition1, default_affect!; maybe_discontinuity = false) condition2(u, t, integrator) = u[1] - 2 -function affect2!(integrator) - #println("Callback 2 fired at t=$(integrator.t), u=$(integrator.u[1])") -end -cb2 = ContinuousCallback(condition2, affect2!; maybe_discontinuity = true) -cb2f = ContinuousCallback(condition2, affect2!; maybe_discontinuity = false) +cb2 = ContinuousCallback(condition2, default_affect!; maybe_discontinuity = true) +cb2f = ContinuousCallback(condition2, default_affect!; maybe_discontinuity = false) cb = CallbackSet(cb1, cb2) cb2 = CallbackSet(cb1f, cb2f) -#disco solve -sol_disco_radau = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 1.503 ms (41672 allocations: 1.27 MiB) -sol_no_disco_radau = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) -# 1.306 ms (37092 allocations: 1.13 MiB) - -sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6, controller = PI_disco_controller(Rodas5P())) # 1.164 ms (44318 allocations: 1.52 MiB) sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) # 1.306 ms (51713 allocations: 1.76 MiB) +@test sol_disco_rosenbrock.retcode == ReturnCode.Success +@test sol_disco_rosenbrock.stats.nreject <= sol_no_disco_rosenbrock.stats.nreject -sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6, controller = PI_disco_controller(Tsit5())) # 279.792 μs (34573 allocations: 1.07 MiB) sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) # 266.167 μs (39024 allocations: 1.21 MiB) - +@test sol_disco_tsit5.retcode == ReturnCode.Success +@test sol_disco_tsit5.stats.nreject <= sol_no_disco_tsit5.stats.nreject #TEST 3: EXPONENTIAL DISCONTINUITY # multiple exponential regions with sharp transitions function f_multi_exp!(du, u, p, t) if u[1] < 0.3 - du[1] = 3 * exp(3 * u[1]) # very steep exponential + du[1] = 3 * exp(3 * u[1]) elseif u[1] < 0.8 - du[1] = exp(u[1]) # slower exponential + du[1] = exp(u[1]) else - du[1] = u[1] # linear + du[1] = u[1] end end @@ -106,50 +103,47 @@ prob_multi = ODEProblem(f_multi_exp!, u0_multi, tspan_multi) #define callbacks cond_multi_1(u, t, integrator) = u[1] - 0.3 -function affect_multi_1!(integrator) - #println("Multi-exponential discontinuity 1 callback fired at t=$(integrator.t), u=$(integrator.u[1])") -end -cb_multi_1 = ContinuousCallback(cond_multi_1, affect_multi_1!; maybe_discontinuity = true) -cb_multi_1f = ContinuousCallback(cond_multi_1, affect_multi_1!; maybe_discontinuity = false) +cb_multi_1 = ContinuousCallback(cond_multi_1, default_affect!; maybe_discontinuity = true) +cb_multi_1f = ContinuousCallback(cond_multi_1, default_affect!; maybe_discontinuity = false) cond_multi_2(u, t, integrator) = u[1] - 0.8 -function affect_multi_2!(integrator) - #println("Multi-exponential discontinuity 2 callback fired at t=$(integrator.t), u=$(integrator.u[1])") -end -cb_multi_2 = ContinuousCallback(cond_multi_2, affect_multi_2!; maybe_discontinuity = true) -cb_multi_2f = ContinuousCallback(cond_multi_2, affect_multi_2!; maybe_discontinuity = false) +cb_multi_2 = ContinuousCallback(cond_multi_2, default_affect!; maybe_discontinuity = true) +cb_multi_2f = ContinuousCallback(cond_multi_2, default_affect!; maybe_discontinuity = false) + cb_multi = CallbackSet(cb_multi_1, cb_multi_2) cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) #disco solve -sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +sol_disco_radau = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9, controller = predictive_disco_controller(RadauIIA5())) # 175.625 μs (1871 allocations: 81.55 KiB) -sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) +sol_no_disco_radau = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) # 142.875 μs (1244 allocations: 59.17 KiB) +@test sol_disco_radau.retcode == ReturnCode.Success +@test sol_disco_radau.stats.nreject <= sol_no_disco_radau.stats.nreject -sol_disco_rosenbrock = solve(prob_multi, Rodas5P(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +sol_disco_rosenbrock = solve(prob_multi, Rodas5P(); callback=cb_multi, reltol=1e-7, abstol=1e-9, controller = PI_disco_controller(Rodas5P())) # 295.834 μs (2216 allocations: 90.70 KiB) +@test sol_disco_radau.retcode == ReturnCode.Success sol_no_disco_rosenbrock = solve(prob_multi, Rodas5P(); callback=cb_multi2, reltol=1e-7, abstol=1e-9) # 253.709 μs (1380 allocations: 74.28 KiB) +@test sol_disco_rosenbrock.retcode == ReturnCode.Success +@test sol_disco_rosenbrock.stats.nreject <= sol_no_disco_rosenbrock.stats.nreject -sol_disco_tsit5 = solve(prob_multi, Tsit5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +sol_disco_tsit5 = solve(prob_multi, Tsit5(); callback=cb_multi, reltol=1e-7, abstol=1e-9, controller = PI_disco_controller(Tsit5())) # 127.375 μs (1953 allocations: 87.49 KiB) +@test sol_disco_rosenbrock.retcode == ReturnCode.Success sol_no_disco_tsit5 = solve(prob_multi, Tsit5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) # 95.250 μs (1499 allocations: 73.62 KiB) - -sol_disco_BS3 = solve(prob_multi, BS3(); callback=cb_multi, reltol=1e-7, abstol=1e-9) -sol_no_disco_BS3 = solve(prob_multi, BS3(); callback=cb_multi2, reltol=1e-7, abstol=1e-9) - -@profview for i in 1:1000 - solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) -end +@test sol_disco_tsit5.retcode == ReturnCode.Success +@test sol_disco_tsit5.stats.nreject <= sol_no_disco_tsit5.stats.nreject #TEST 4: STIFF DISCONTINUITY # very stiff discontinuous system +@test sol_disco_tsit5.retcode == ReturnCode.Success function f_stiff_disc!(du, u, p, t) - λ = p[1] # stiffness parameter + λ = p[1] # stiffness parameter if u[1] < 0.5 - du[1] = -λ * u[1] + λ * exp(-t) # stiff decay with forcing + du[1] = -λ * u[1] + λ * exp(-t) # stiff decay with forcing else du[1] = u[1] end @@ -161,32 +155,23 @@ prob_stiff = ODEProblem(f_stiff_disc!, u0_stiff, tspan_stiff, [100.0]) #define callback cond_stiff(u, t, integrator) = u[1] - 0.5 -function affect_stiff!(integrator) - #println("Stiff discontinuity callback fired at t=$(integrator.t), u=$(integrator.u[1])") -end -cb_stiff = ContinuousCallback(cond_stiff, affect_stiff!; maybe_discontinuity = true) -cb_stiff_f = ContinuousCallback(cond_stiff, affect_stiff!; maybe_discontinuity = false) +cb_stiff = ContinuousCallback(cond_stiff, default_affect!; maybe_discontinuity = true) +cb_stiff_f = ContinuousCallback(cond_stiff, default_affect!; maybe_discontinuity = false) #disco solve -sol_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +sol_disco_radau = solve(prob_stiff, RadauIIA5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11, controller = predictive_disco_controller(RadauIIA5())) # 149.167 μs (1819 allocations: 75.19 KiB) -sol_no_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) +sol_no_disco_radau = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) # 138.125 μs (1565 allocations: 64.09 KiB) +@test sol_disco_radau.retcode == ReturnCode.Success +@test sol_disco_radau.stats.nreject <= sol_no_disco_radau.stats.nreject -sol_disco_rosenbrock = solve(prob_stiff, Rodas5P(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) -# 204.833 μs (1517 allocations: 59.33 KiB) -sol_no_disco_rosenbrock = solve(prob_stiff, Rodas5P(); callback=cb_stiff_f, reltol=1e-9, abstol=1e-11) -# 156.500 μs (1047 allocations: 44.59 KiB) - -sol_disco_tsit5 = solve(prob_stiff, Tsit5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +sol_disco_tsit5 = solve(prob_stiff, Tsit5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11, controller = PI_disco_controller(Tsit5())) # 93.833 μs (2040 allocations: 80.59 KiB) sol_no_disco_tsit5 = solve(prob_stiff, Tsit5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) # 82.750 μs (1898 allocations: 72.12 KiB) - -sol_disco_BS3 = solve(prob_stiff, BS3(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) -# 1.121 ms (12460 allocations: 595.30 KiB) -sol_no_disco_BS3 = solve(prob_stiff, BS3(); callback=cb_stiff_f, reltol=1e-9, abstol=1e-11) -# 1.102 ms (12229 allocations: 582.34 KiB) +@test sol_disco_tsit5.retcode == ReturnCode.Success +@test sol_disco_tsit5.stats.nreject <= sol_no_disco_tsit5.stats.nreject #TEST 5: DISCONTINUOUS DAE # discontinuous DAE with mass matrix @@ -196,14 +181,14 @@ sol_no_disco_BS3 = solve(prob_stiff, BS3(); callback=cb_stiff_f, reltol=1e-9, ab function f_dae_disc!(du, u, p, t) if u[1] < 0.5 du[1] = 2 * u[2] - u[1] - du[2] = u[1] + u[2] - 1 # algebraic constraint + du[2] = u[1] + u[2] - 1 # algebraic constraint else du[1] = -u[1] + u[2] - du[2] = u[1] + u[2] - 1 # algebraic constraint + du[2] = u[1] + u[2] - 1 end end -u0_dae = [0.2, 0.8] # consistent with constraint u[1] + u[2] = 1 +u0_dae = [0.2, 0.8] tspan_dae = (0.0, 2.0) M_dae = [1.0 0.0; 0.0 0.0] @@ -212,21 +197,15 @@ f_dae_func = ODEFunction(f_dae_disc!; mass_matrix=M_dae) prob_dae = ODEProblem(f_dae_func, u0_dae, tspan_dae) cond_dae(u, t, integrator) = u[1] - 0.5 -function affect_dae!(integrator) - #println("DAE discontinuity callback fired at t=$(integrator.t), u=$(integrator.u)") -end -cb_dae = ContinuousCallback(cond_dae, affect_dae!; maybe_discontinuity = true) -cb_daef = ContinuousCallback(cond_dae, affect_dae!; maybe_discontinuity = false) +cb_dae = ContinuousCallback(cond_dae, default_affect!; maybe_discontinuity = true) +cb_daef = ContinuousCallback(cond_dae, default_affect!; maybe_discontinuity = false) -radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol=1e-10) +radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol=1e-10, controller = predictive_disco_controller(RadauIIA5())) # 88.542 μs (870 allocations: 41.86 KiB) radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) # 73.000 μs (673 allocations: 32.05 KiB) - -sol_disco_rosenbrock = solve(prob_dae, Rodas5P(); callback=cb_dae, reltol=1e-8, abstol=1e-10) -# 312.167 μs (1200 allocations: 48.73 KiB) -sol_no_disco_rosenbrock = solve(prob_dae, Rodas5P(); callback=cb_daef, reltol=1e-8, abstol=1e-10) -# 256.792 μs (672 allocations: 32.56 KiB) +@test radau_disco.retcode == ReturnCode.Success +@test radau_disco.stats.nreject <= radau_no_disco.stats.nreject #TEST 6: VECTOR CALLBACK function f!(du, u, p, t) @@ -238,7 +217,7 @@ u0 = [3.0, 0.0] tspan = (0.0, 10.0) prob = ODEProblem(f!, u0, tspan) -# Two event surfaces: u[1] == 2.0 and u[1] == 1.0 +# u[1] == 2.0 and u[1] == 1.0 function condition!(out, u, t, integrator) out[1] = u[1] - 2.0 out[2] = u[1] - 1.0 @@ -258,20 +237,26 @@ end cb = VectorContinuousCallback(condition!, affect!, 2; maybe_discontinuity = true) cb2 = VectorContinuousCallback(condition!, affect!, 2; maybe_discontinuity = false) -sol_disco = solve(prob, RadauIIA5(); callback = cb) +sol_disco_radau = solve(prob, RadauIIA5(); callback = cb, controller = predictive_disco_controller(RadauIIA5())) # 49.125 μs (664 allocations: 32.89 KiB) -sol_no_disco = solve(prob, RadauIIA5(); callback = cb2) +sol_no_disco_radau = solve(prob, RadauIIA5(); callback = cb2) # 37.375 μs (531 allocations: 25.23 KiB) +@test sol_disco_radau.retcode == ReturnCode.Success +@test sol_disco_radau.stats.nreject <= sol_no_disco_radau.stats.nreject -sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb) +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, controller = PI_disco_controller(Rodas5P())) # 57.333 μs (592 allocations: 31.23 KiB) sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2) # 44.250 μs (476 allocations: 23.73 KiB) +@test sol_disco_rosenbrock.retcode == ReturnCode.Success +@test sol_disco_rosenbrock.stats.nreject <= sol_no_disco_rosenbrock.stats.nreject -sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb) +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, controller = PI_disco_controller(Tsit5())) # 37.833 μs (673 allocations: 31.80 KiB) sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2) # 24.958 μs (557 allocations: 24.23 KiB) +@test sol_disco_tsit5.retcode == ReturnCode.Success +@test sol_disco_tsit5.stats.nreject <= sol_no_disco_tsit5.stats.nreject #TEST 7 function f!(du, u, p, t) @@ -289,20 +274,19 @@ tspan = (0.0, 2.0) prob = ODEProblem(f!, u, tspan) cond(u, t, integrator) = u[2] -affect!(integrator) = nothing - -cb = ContinuousCallback(cond, affect!; maybe_discontinuity = true) -cb2 = ContinuousCallback(cond, affect!; maybe_discontinuity = false) - -sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-8, abstol = 1e-10) -sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) +cb = ContinuousCallback(cond, default_affect!; maybe_discontinuity = true) +cb2 = ContinuousCallback(cond, default_affect!; maybe_discontinuity = false) -sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-8, abstol = 1e-10) +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-8, abstol = 1e-10, controller = PI_disco_controller(Rodas5P())) # 240.291 μs (1821 allocations: 71.56 KiB) sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-8, abstol = 1e-10) # 184.625 μs (1029 allocations: 49.23 KiB) +@test sol_disco_rosenbrock.retcode == ReturnCode.Success +@test sol_disco_rosenbrock.stats.nreject <= sol_no_disco_rosenbrock.stats.nreject -sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10) +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10, controller = PI_disco_controller(Tsit5())) # 79.791 μs (1678 allocations: 73.85 KiB) sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) -# 55.958 μs (1259 allocations: 57.04 KiB) \ No newline at end of file +# 55.958 μs (1259 allocations: 57.04 KiB) +@test sol_disco_tsit5.retcode == ReturnCode.Success +@test sol_disco_tsit5.stats.nreject <= sol_no_disco_tsit5.stats.nreject \ No newline at end of file From 191b42f6c5022b3e853f7912ff31cb189a893749 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Fri, 8 May 2026 15:21:57 -0400 Subject: [PATCH 24/27] bdf controllers --- lib/OrdinaryDiffEqBDF/src/controllers.jl | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/lib/OrdinaryDiffEqBDF/src/controllers.jl b/lib/OrdinaryDiffEqBDF/src/controllers.jl index 78604eae0b5..10a3514119a 100644 --- a/lib/OrdinaryDiffEqBDF/src/controllers.jl +++ b/lib/OrdinaryDiffEqBDF/src/controllers.jl @@ -98,14 +98,18 @@ function bdf_step_reject_controller!(integrator, cache, EEst1) integrator.cache.consfailcnt += 1 integrator.cache.nconsteps = 0 - disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) - if disco_dt != -1 - integrator.dt = disco_dt - return integrator.dt - end - cache.consfailcnt += 1 cache.nconsteps = 0 + + discontinuity_detection = integrator.controller_cache.controller.discontinuity_detection + if discontinuity_detection + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end + end + if cache.consfailcnt > 1 h = h / 2 end From ee0b2bce56106c805a298d43090c1e6ef87e4dda Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Fri, 8 May 2026 15:22:34 -0400 Subject: [PATCH 25/27] whoops --- lib/OrdinaryDiffEqBDF/src/controllers.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/OrdinaryDiffEqBDF/src/controllers.jl b/lib/OrdinaryDiffEqBDF/src/controllers.jl index 10a3514119a..4527afa10a4 100644 --- a/lib/OrdinaryDiffEqBDF/src/controllers.jl +++ b/lib/OrdinaryDiffEqBDF/src/controllers.jl @@ -95,9 +95,6 @@ end function bdf_step_reject_controller!(integrator, cache, EEst1) k = cache.order h = integrator.dt - integrator.cache.consfailcnt += 1 - integrator.cache.nconsteps = 0 - cache.consfailcnt += 1 cache.nconsteps = 0 From 914a1d4986f354207c328551bd636ad5ff38bfe2 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Fri, 8 May 2026 15:35:08 -0400 Subject: [PATCH 26/27] revert bdf edits --- lib/OrdinaryDiffEqBDF/src/controllers.jl | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/lib/OrdinaryDiffEqBDF/src/controllers.jl b/lib/OrdinaryDiffEqBDF/src/controllers.jl index 4527afa10a4..69d008af77f 100644 --- a/lib/OrdinaryDiffEqBDF/src/controllers.jl +++ b/lib/OrdinaryDiffEqBDF/src/controllers.jl @@ -98,15 +98,12 @@ function bdf_step_reject_controller!(integrator, cache, EEst1) cache.consfailcnt += 1 cache.nconsteps = 0 - discontinuity_detection = integrator.controller_cache.controller.discontinuity_detection - if discontinuity_detection - disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) - if disco_dt != -1 - integrator.dt = disco_dt - return integrator.dt - end + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt end - + if cache.consfailcnt > 1 h = h / 2 end From 414eb2c157ec5d976fe6ffa9e1a02228076a598f Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Sat, 9 May 2026 12:20:03 -0400 Subject: [PATCH 27/27] composite controller fix --- lib/OrdinaryDiffEqBDF/src/controllers.jl | 11 +++++++---- lib/OrdinaryDiffEqCore/src/integrators/controllers.jl | 8 ++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/lib/OrdinaryDiffEqBDF/src/controllers.jl b/lib/OrdinaryDiffEqBDF/src/controllers.jl index 69d008af77f..f0e1cf1f328 100644 --- a/lib/OrdinaryDiffEqBDF/src/controllers.jl +++ b/lib/OrdinaryDiffEqBDF/src/controllers.jl @@ -98,10 +98,13 @@ function bdf_step_reject_controller!(integrator, cache, EEst1) cache.consfailcnt += 1 cache.nconsteps = 0 - disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) - if disco_dt != -1 - integrator.dt = disco_dt - return integrator.dt + discontinuity_detection = cache.controller.discontinuity_detection + if discontinuity_detection + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end end if cache.consfailcnt > 1 diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index 41f9b206b60..eca0a760caf 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -334,7 +334,7 @@ function step_accept_controller!(integrator, cache::IControllerCache, alg, q) end function step_reject_controller!(integrator, cache::IControllerCache, alg) - discontinuity_detection = integrator.controller_cache.controller.discontinuity_detection + discontinuity_detection = cache.controller.discontinuity_detection if discontinuity_detection disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) if disco_dt != -1 @@ -494,7 +494,7 @@ end function step_reject_controller!(integrator, cache::PIControllerCache, alg) (; controller, q11) = cache (; qmin, gamma) = controller - discontinuity_detection = integrator.controller_cache.controller.discontinuity_detection + discontinuity_detection = cache.controller.discontinuity_detection if discontinuity_detection disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) if disco_dt != -1 @@ -729,7 +729,7 @@ function step_accept_controller!(integrator, cache::PIDControllerCache, alg, dt_ end function step_reject_controller!(integrator, cache::PIDControllerCache, alg) - discontinuity_detection = integrator.controller_cache.controller.discontinuity_detection + discontinuity_detection = cache.controller.discontinuity_detection if discontinuity_detection disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) if disco_dt != -1 @@ -929,7 +929,7 @@ end function step_reject_controller!(integrator, cache::PredictiveControllerCache, alg) (; dt, success_iter) = integrator (; qold) = cache - discontinuity_detection = integrator.controller_cache.controller.discontinuity_detection + discontinuity_detection = cache.controller.discontinuity_detection if discontinuity_detection disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) if disco_dt != -1