diff --git a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl index e825803cf5..ad513ed6f9 100644 --- a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl +++ b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl @@ -22,7 +22,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, DAEAlgorithm, _unwrap_val, DummyController, get_fsalfirstlast, generic_solver_docstring, _ad_chunksize_int, _ad_fdtype, _fixup_ad, _ode_interpolant, _ode_interpolant!, has_stiff_interpolation, - _ode_addsteps!, DerivativeOrderNotPossibleError + _ode_addsteps!, DerivativeOrderNotPossibleError, set_discontinuity using OrdinaryDiffEqSDIRK: ImplicitEulerConstantCache, ImplicitEulerCache using TruncatedStacktraces: @truncate_stacktrace diff --git a/lib/OrdinaryDiffEqBDF/src/controllers.jl b/lib/OrdinaryDiffEqBDF/src/controllers.jl index 1ff2a32a14..f0e1cf1f32 100644 --- a/lib/OrdinaryDiffEqBDF/src/controllers.jl +++ b/lib/OrdinaryDiffEqBDF/src/controllers.jl @@ -97,6 +97,16 @@ function bdf_step_reject_controller!(integrator, cache, EEst1) h = integrator.dt cache.consfailcnt += 1 cache.nconsteps = 0 + + discontinuity_detection = cache.controller.discontinuity_detection + if discontinuity_detection + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end + end + if cache.consfailcnt > 1 h = h / 2 end @@ -495,4 +505,4 @@ function step_accept_controller!( cache.qwait -= 1 # countdown end return integrator.dt / q -end +end diff --git a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl index 2987bbc00c..042b425bbc 100644 --- a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl +++ b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl @@ -152,6 +152,7 @@ include("cache_utils.jl") include("initialize_dae.jl") include("perform_step/composite_perform_step.jl") +include("disco.jl") include("dense/generic_dense.jl") diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl new file mode 100644 index 0000000000..3c2190c392 --- /dev/null +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -0,0 +1,74 @@ +function set_discontinuity(u, uprev, integrator, cache) + breakpointθ = find_discontinuity(u, uprev, integrator, cache) + dt = integrator.dt + t = integrator.t + if !isnan(breakpointθ) && 1e-6 < breakpointθ < 1.0 + #println("Discontinuity detected at t = ", t + breakpointθ * dt) + return breakpointθ * dt + end + return -1 +end + +function find_discontinuity(u, uprev, integrator, cache) + cb = integrator.opts.callback + cb === nothing && return -1 + isempty(cb.continuous_callbacks) && return -1 + p = integrator.p + t = integrator.t + dt = integrator.dt + save_idxs = integrator.opts.save_idxs + k = integrator.k + cache = integrator.cache + differential_vars = integrator.differential_vars + θlo = zero(dt) + θhi = one(dt) + bracket = [θlo, θhi] + breakpointθ = -one(dt) + idx = 1 + for i in cb.continuous_callbacks + if (!(i.maybe_discontinuity)) + continue + end + disco_prob = integrator.disco_probs[idx] + disco_zero = disco_prob.f.f + disco_zero.dt = dt + disco_zero.uprev = uprev + disco_zero.u = u + disco_zero.k = k + disco_zero.cache = cache + disco_zero.differential_vars = differential_vars + disco_zero.idxs = save_idxs + disco_zero.tprev = t + disco_zero.f = integrator.f + disco_zero.p = p + if (i isa VectorContinuousCallback) + len_cb = i.len + out_prev = similar(u) + out_curr = similar(u) + i.condition(out_prev, uprev, t, integrator) + i.condition(out_curr, u, t + dt, integrator) + for j in 1:len_cb + if (out_prev[j] * out_curr[j] < zero(out_prev[j])) + disco_zero.ind = j + sol = solve(disco_prob; bracket = bracket) + tmp = sol[] + if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) + breakpointθ = tmp + end + end + end + else + out_prev = i.condition(uprev, t, integrator) + out_curr = i.condition(u, t + dt, integrator) + if (out_prev * out_curr < zero(out_prev)) + sol = solve(disco_prob; bracket = bracket) + tmp = sol[] + if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) + breakpointθ = tmp + end + end + end + idx += 1 + end + breakpointθ +end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index 859b9cfed9..eca0a760ca 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -185,14 +185,20 @@ Controller cache used by algorithms that manage step-size selection themselves (BDF, Nordsieck, Leaping, …). Holds the scalar error estimate exposed through `get_EEst(integrator)` and a reference to the algorithm cache so existing dispatch on the algorithm cache continues to work. +If `discontinuity_detection` is set to true, the algorithm will run the autonomous +discontinuity detection to predict the best next timestep after step rejection. +Otherwise, it follows the default step rejection algorithm. This feature is currently +defaulted off. """ mutable struct DummyControllerCache{T, C} <: AbstractControllerCache EEst::T cache::C + discontinuity_detection::Bool end function setup_controller_cache(alg, cache, controller::DummyController, ::Type{E}) where {E} - return DummyControllerCache{E, typeof(cache)}(oneunit(E), cache) + discontinuity_detection = false + return DummyControllerCache{E, typeof(cache)}(oneunit(E), cache, discontinuity_detection) end # Algorithms with integrated controllers (BDF, Nordsieck, …) only define their @@ -236,6 +242,10 @@ the interval `[qmin, qmax]`. A step will be accepted whenever the estimated error `get_EEst(integrator)` is less than or equal to unity. Otherwise, the step is rejected and re-tried with the predicted step size. +If `discontinuity_detection` is set to true, the algorithm will run the autonomous +discontinuity detection to predict the best next timestep after step rejection. +Otherwise, it follows the default step rejection algorithm. This feature is currently +defaulted off. ## References @@ -250,9 +260,11 @@ struct IController{T} <: AbstractController gamma::T qsteady_min::T qsteady_max::T + discontinuity_detection::Bool end function IController(; qmin = 1 // 5, qmax = 10 // 1, qmax_first_step = 10000 // 1, gamma = 9 // 10, qsteady_min = 1 // 1, qsteady_max = 6 // 5) + discontinuity_detection = false return IController{typeof(qmin)}( # FIXME combined promoted type qmin, qmax, @@ -260,6 +272,7 @@ function IController(; qmin = 1 // 5, qmax = 10 // 1, qmax_first_step = 10000 // gamma, qsteady_min, qsteady_max, + discontinuity_detection ) end @@ -267,7 +280,7 @@ function IController(alg; kwargs...) return IController(Float64, alg; kwargs...) end -function IController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing) +function IController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, discontinuity_detection = nothing) return IController{QT}( qmin === nothing ? qmin_default(alg) : qmin, qmax === nothing ? qmax_default(alg) : qmax, @@ -275,6 +288,7 @@ function IController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = gamma === nothing ? gamma_default(alg) : gamma, qsteady_min === nothing ? qsteady_min_default(alg) : qsteady_min, qsteady_max === nothing ? qsteady_max_default(alg) : qsteady_max, + discontinuity_detection === nothing ? false : discontinuity_detection ) end @@ -320,7 +334,15 @@ function step_accept_controller!(integrator, cache::IControllerCache, alg, q) end function step_reject_controller!(integrator, cache::IControllerCache, alg) - return integrator.dt = cache.dtreject + discontinuity_detection = cache.controller.discontinuity_detection + if discontinuity_detection + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end + end + return integrator.dt = cache.dtreject # TODO this does not look right. end reinit_controller!(integrator::SciMLBase.DEIntegrator, cache::IControllerCache) = nothing @@ -351,7 +373,10 @@ the interval `[qmin, qmax]`. A step will be accepted whenever the estimated error `get_EEst(integrator)` is less than or equal to unity. Otherwise, the step is rejected and re-tried with the predicted step size. - +If `discontinuity_detection` is set to true, the algorithm will run the autonomous +discontinuity detection to predict the best next timestep after step rejection. +Otherwise, it follows the default step rejection algorithm. This feature is currently +defaulted off. !!! note The coefficients `beta1, beta2` are not scaled by the order of the method, @@ -377,9 +402,11 @@ mutable struct PIController{T} <: AbstractController # TODO remove the mutable o qsteady_min::T qsteady_max::T qoldinit::T + discontinuity_detection::Bool end function PIController(beta1::Real, beta2::Real; qmin = 1 // 5, qmax = 10 // 0, qmax_first_step = 10000 // 1, gamma = 9 // 10, qsteady_min = 1 // 1, qsteady_max = 6 // 5, qoldinit = 1 // 10^4) + discontinuity_detection = false return PIController{typeof(beta1)}( beta1, beta2, @@ -390,6 +417,7 @@ function PIController(beta1::Real, beta2::Real; qmin = 1 // 5, qmax = 10 // 0, q qsteady_min, qsteady_max, qoldinit, + discontinuity_detection ) end @@ -397,7 +425,7 @@ function PIController(alg; kwargs...) return PIController(Float64, alg; kwargs...) end -function PIController(QT, alg; beta1 = nothing, beta2 = nothing, qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, qoldinit = nothing) +function PIController(QT, alg; beta1 = nothing, beta2 = nothing, qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, qoldinit = nothing, discontinuity_detection = nothing) beta2 = beta2 === nothing ? beta2_default(alg) : beta2 beta1 = beta1 === nothing ? beta1_default(alg, beta2) : beta1 qoldinit = qoldinit === nothing ? 1 // 10^4 : qoldinit @@ -410,7 +438,8 @@ function PIController(QT, alg; beta1 = nothing, beta2 = nothing, qmin = nothing, gamma === nothing ? gamma_default(alg) : gamma, qsteady_min === nothing ? qsteady_min_default(alg) : qsteady_min, qsteady_max === nothing ? qsteady_max_default(alg) : qsteady_max, - qoldinit, + qoldinit, + discontinuity_detection === nothing ? false : discontinuity_detection ) end @@ -465,6 +494,14 @@ end function step_reject_controller!(integrator, cache::PIControllerCache, alg) (; controller, q11) = cache (; qmin, gamma) = controller + discontinuity_detection = cache.controller.discontinuity_detection + if discontinuity_detection + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end + end return integrator.dt /= min(inv(qmin), q11 / gamma) end @@ -517,6 +554,11 @@ Some standard controller parameters suggested in the literature are | H211PI | `1//6` | `1//6` | `0` | | H312PID | `1//18` | `1//9` | `1//18` | +If `discontinuity_detection` is set to true, the algorithm will run the autonomous +discontinuity detection to predict the best next timestep after step rejection. +Otherwise, it follows the default step rejection algorithm. This feature is currently +defaulted off. + !!! note In contrast to the [`PIController`](@ref), the coefficients `beta1, beta2, beta3` @@ -552,18 +594,21 @@ struct PIDController{QT, Limiter} <: AbstractController limiter::Limiter # limiter of the dt factor (before clipping) qsteady_min::QT qsteady_max::QT + discontinuity_detection::Bool end @inline default_dt_factor_limiter(x) = one(x) + atan(x - one(x)) function PIDController(beta1::Real, beta2::Real, beta3::Real = zero(beta1); accept_safety = 0.81, limiter = default_dt_factor_limiter, qsteady_min = 1 // 1, qsteady_max = 6 // 5) beta = map(float, promote(beta1, beta2, beta3)) + discontinuity_detection = false return PIDController{typeof(beta1), typeof(limiter)}( beta, accept_safety, limiter, qsteady_min, qsteady_max, + discontinuity_detection ) end @@ -586,6 +631,7 @@ function PIDController(QT, alg; beta = nothing, accept_safety = 0.81, limiter = limiter, QT(qsteady_min === nothing ? qsteady_min_default(alg) : qsteady_min), QT(qsteady_max === nothing ? qsteady_max_default(alg) : qsteady_max), + discontinuity_detection === nothing ? false : discontinuity_detection ) end @@ -683,6 +729,14 @@ function step_accept_controller!(integrator, cache::PIDControllerCache, alg, dt_ end function step_reject_controller!(integrator, cache::PIDControllerCache, alg) + discontinuity_detection = cache.controller.discontinuity_detection + if discontinuity_detection + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end + end return integrator.dt *= cache.dt_factor end @@ -738,7 +792,10 @@ integrator.dt / qacc ``` When it rejects, it's the same as the [`IController`](@ref): - +If `discontinuity_detection` is set to true, the algorithm will run the autonomous +discontinuity detection to predict the best next timestep after step rejection. +Otherwise, it follows the default step rejection algorithm. This feature is currently +defaulted off. ```julia if integrator.success_iter == 0 integrator.dt *= 0.1 @@ -754,9 +811,11 @@ struct PredictiveController{T} <: AbstractController gamma::T qsteady_min::T qsteady_max::T + discontinuity_detection::Bool end function PredictiveController(; qmin = float(1 // 5), qmax = 10 // 1, qmax_first_step = 10000 // 1, gamma = 9 // 10, qsteady_min = 1 // 1, qsteady_max = 6 // 5) + discontinuity_detection = false return PredictiveController{typeof(qmin)}( # FIXME combined promoted type qmin, qmax, @@ -764,6 +823,7 @@ function PredictiveController(; qmin = float(1 // 5), qmax = 10 // 1, qmax_first gamma, qsteady_min, qsteady_max, + discontinuity_detection ) end @@ -771,7 +831,7 @@ function PredictiveController(alg; kwargs...) return PredictiveController(Float64, alg; kwargs...) end -function PredictiveController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing) +function PredictiveController(QT, alg; qmin = nothing, qmax = nothing, qmax_first_step = nothing, gamma = nothing, qsteady_min = nothing, qsteady_max = nothing, discontinuity_detection = nothing) return PredictiveController{QT}( qmin === nothing ? qmin_default(alg) : qmin, qmax === nothing ? qmax_default(alg) : qmax, @@ -779,6 +839,7 @@ function PredictiveController(QT, alg; qmin = nothing, qmax = nothing, qmax_firs gamma === nothing ? gamma_default(alg) : gamma, qsteady_min === nothing ? qsteady_min_default(alg) : qsteady_min, qsteady_max === nothing ? qsteady_max_default(alg) : qsteady_max, + discontinuity_detection === nothing ? false : discontinuity_detection ) end @@ -868,6 +929,14 @@ end function step_reject_controller!(integrator, cache::PredictiveControllerCache, alg) (; dt, success_iter) = integrator (; qold) = cache + discontinuity_detection = cache.controller.discontinuity_detection + if discontinuity_detection + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end + end return integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/type.jl b/lib/OrdinaryDiffEqCore/src/integrators/type.jl index 180f1af5cc..2c2547ce1d 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/type.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/type.jl @@ -175,6 +175,8 @@ mutable struct ODEIntegrator{ fsalfirst::FSALType fsallast::FSALType rng::RNGType + #disco_prob::IntervalNonlinearProblem + disco_probs::Vector{IntervalNonlinearProblem} #should we change this? W::WType P::PType sqdt::SqdtType diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index b9a515af0c..d9fa6dc480 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -16,6 +16,36 @@ determine_controller_datatype(u, internalnorm, ts::Tuple{<:Number, <:Number}) = determine_controller_datatype(u::AbstractVector{<:Number}, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(DiffEqBase.value(internalnorm(u, ts[1]))), typeof(DiffEqBase.value(internalnorm(u, ts[2]))), eltype(float.(DiffEqBase.value(ts)))) determine_controller_datatype(u, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(float(DiffEqBase.value(ts[1]))), typeof(float(DiffEqBase.value(ts[2])))) # This seems to be an assumption implicitly taken somewhere +mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsType, callbackType, outType, FunctionType, tType2, ParameterType} + #integrator_ref::IntegratorType + u₁::uType + callback::callbackType + dt::tType + uprev::uType + u::uType + k::kType + cache::CacheType + idxs::idxsType + differential_vars::varsType + ind::Int + out::outType + f::FunctionType + tprev::tType2 + p::ParameterType +end + +function (z::zero_func_struct)(θ, p) + _ode_addsteps!(z.k, z.tprev, z.uprev, z.u, z.dt, z.f, z.p, z.cache, false, true, false) + ode_interpolant!(z.u₁, θ, z.dt, z.uprev, z.u, z.k, z.cache, z.idxs, Val{0}, z.differential_vars) + return zero_condition(z.callback, z.out, z.u₁, z.dt + θ * z.dt, z, z.ind) +end + +@inline zero_condition(cb::ContinuousCallback, out::Nothing, u, t, z, ind) = cb.condition(u, t, z) +@inline function zero_condition(cb::VectorContinuousCallback, out, u, t, z, ind) + cb.condition(out, u, t, z) + return out[ind] +end + Base.@constprop :aggressive function SciMLBase.__init( prob::Union{ SciMLBase.AbstractODEProblem, @@ -99,6 +129,7 @@ Base.@constprop :aggressive function _ode_init( alias = ODEAliasSpecifier(), initializealg = DefaultInit(), rng = nothing, + disco_probs = nothing, # SDE/RODE fields: accepted here so that SDE packages can delegate to # _ode_init and construct an ODEIntegrator with noise populated. save_noise = false, @@ -626,6 +657,23 @@ Base.@constprop :aggressive function _ode_init( _rng = rng === nothing ? Random.default_rng() : rng + num_probs = 0 + for i in callbacks_internal.continuous_callbacks + if i.maybe_discontinuity + num_probs += 1 + end + end + disco_probs = Vector{IntervalNonlinearProblem}(undef, num_probs) + idx = 1 + for i in callbacks_internal.continuous_callbacks + if i.maybe_discontinuity + u₁ = similar(u) + out = i isa VectorContinuousCallback ? similar(u) : nothing + zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars, 1, out, f, tprev, p) + disco_probs[idx] = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p) + idx += 1 + end + end # Seed the initial EEst on the controller cache (was previously # `integrator.EEst = oneunit(EEstT)`). set_EEst!(controller_cache, EEst) @@ -640,7 +688,7 @@ Base.@constprop :aggressive function _ode_init( typeof(initializealg), typeof(differential_vars), typeof(controller_cache), typeof(_rng), typeof(W), typeof(P), typeof(sqdt), - typeof(noise), typeof(c), typeof(rate_constants), + typeof(noise), typeof(c), typeof(rate_constants) }( sol, u, du, k, t, tType(_dt), f, p, uprev, uprev2, duprev, tprev, @@ -659,7 +707,7 @@ Base.@constprop :aggressive function _ode_init( isout, reeval_fsal, derivative_discontinuity, reinitialize, isdae, opts, stats, initializealg, differential_vars, - fsalfirst, fsallast, _rng, + fsalfirst, fsallast, _rng, disco_probs, W, P, sqdt, noise, c, rate_constants ) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl new file mode 100644 index 0000000000..4f9e94b6e6 --- /dev/null +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -0,0 +1,292 @@ +using OrdinaryDiffEqCore, DiffEqDevTools, Test, LinearAlgebra +using OrdinaryDiffEqTsit5, OrdinaryDiffEqRosenbrock, OrdinaryDiffEqLowOrderRK, OrdinaryDiffEqFIRK +using Logging +global_logger(ConsoleLogger(stderr, Logging.Error)) +using BenchmarkTools + +predictive_disco_controller(alg) = OrdinaryDiffEqCore.PredictiveController(alg; discontinuity_detection = true) +PI_disco_controller(alg) = OrdinaryDiffEqCore.PIController(alg; discontinuity_detection = true) + +function default_affect!(integrator) + nothing +end + +#TEST 1: SIMPLE DISCONTINUITY +#test example discontinuous at u = 1 +f(u, p, t) = u[1] < 1 ? [2u[1]] : [-3u[1] + 5] +u0 = [0.1] +tspan = (0.0, 1.5) +prob = ODEProblem(f, u0, tspan) + +#define callback +condition(u, t, integrator) = u[1] - 1 +cb = ContinuousCallback(condition, default_affect!; maybe_discontinuity = true) +cb2 = ContinuousCallback(condition, default_affect!; maybe_discontinuity = false) + +sol_disco_radau = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6, controller = predictive_disco_controller(RadauIIA5())) +# 294.458 μs (8082 allocations: 256.59 KiB) +sol_no_disco_radau = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) +# 356.708 μs (10024 allocations: 312.08 KiB) +@test sol_disco_radau.retcode == ReturnCode.Success +@test sol_disco_radau.stats.nreject <= sol_no_disco_radau.stats.nreject + +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6, controller = PI_disco_controller(Rodas5P())) +# 474.375 μs (16801 allocations: 592.50 KiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) +# 509.083 μs (18240 allocations: 639.33 KiB) +@test sol_disco_rosenbrock.retcode == ReturnCode.Success +@test sol_disco_rosenbrock.stats.nreject <= sol_no_disco_rosenbrock.stats.nreject + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6, controller = PI_disco_controller(Tsit5())) +# 59.542 μs (7248 allocations: 233.67 KiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) +# 46.500 μs (7129 allocations: 226.22 KiB) +@test sol_disco_tsit5.retcode == ReturnCode.Success +@test sol_disco_tsit5.stats.nreject <= sol_no_disco_tsit5.stats.nreject + +#TEST 2: TWO DISCONTINUITIES +#two discontinuity functions +function f(u, p, t) + if u[1] < 1 + [2u[1]] + elseif u[1] < 2 + [u[1] + 0.2] + else + [-4u[1] + 12] + end +end + +u0 = [0.1] +tspan = (0.0, 2.5) +prob = ODEProblem(f, u0, tspan) + +#define callbacks +condition1(u, t, integrator) = u[1] - 1 +cb1 = ContinuousCallback(condition1, default_affect!; maybe_discontinuity = true) +cb1f = ContinuousCallback(condition1, default_affect!; maybe_discontinuity = false) + +condition2(u, t, integrator) = u[1] - 2 +cb2 = ContinuousCallback(condition2, default_affect!; maybe_discontinuity = true) +cb2f = ContinuousCallback(condition2, default_affect!; maybe_discontinuity = false) +cb = CallbackSet(cb1, cb2) +cb2 = CallbackSet(cb1f, cb2f) + +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6, controller = PI_disco_controller(Rodas5P())) +# 1.164 ms (44318 allocations: 1.52 MiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) +# 1.306 ms (51713 allocations: 1.76 MiB) +@test sol_disco_rosenbrock.retcode == ReturnCode.Success +@test sol_disco_rosenbrock.stats.nreject <= sol_no_disco_rosenbrock.stats.nreject + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6, controller = PI_disco_controller(Tsit5())) +# 279.792 μs (34573 allocations: 1.07 MiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) +# 266.167 μs (39024 allocations: 1.21 MiB) +@test sol_disco_tsit5.retcode == ReturnCode.Success +@test sol_disco_tsit5.stats.nreject <= sol_no_disco_tsit5.stats.nreject + +#TEST 3: EXPONENTIAL DISCONTINUITY +# multiple exponential regions with sharp transitions +function f_multi_exp!(du, u, p, t) + if u[1] < 0.3 + du[1] = 3 * exp(3 * u[1]) + elseif u[1] < 0.8 + du[1] = exp(u[1]) + else + du[1] = u[1] + end +end + +u0_multi = [0.05] +tspan_multi = (0.0, 1.5) +prob_multi = ODEProblem(f_multi_exp!, u0_multi, tspan_multi) + +#define callbacks +cond_multi_1(u, t, integrator) = u[1] - 0.3 +cb_multi_1 = ContinuousCallback(cond_multi_1, default_affect!; maybe_discontinuity = true) +cb_multi_1f = ContinuousCallback(cond_multi_1, default_affect!; maybe_discontinuity = false) + +cond_multi_2(u, t, integrator) = u[1] - 0.8 +cb_multi_2 = ContinuousCallback(cond_multi_2, default_affect!; maybe_discontinuity = true) +cb_multi_2f = ContinuousCallback(cond_multi_2, default_affect!; maybe_discontinuity = false) + +cb_multi = CallbackSet(cb_multi_1, cb_multi_2) +cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) + +#disco solve +sol_disco_radau = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9, controller = predictive_disco_controller(RadauIIA5())) +# 175.625 μs (1871 allocations: 81.55 KiB) +sol_no_disco_radau = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) +# 142.875 μs (1244 allocations: 59.17 KiB) +@test sol_disco_radau.retcode == ReturnCode.Success +@test sol_disco_radau.stats.nreject <= sol_no_disco_radau.stats.nreject + +sol_disco_rosenbrock = solve(prob_multi, Rodas5P(); callback=cb_multi, reltol=1e-7, abstol=1e-9, controller = PI_disco_controller(Rodas5P())) +# 295.834 μs (2216 allocations: 90.70 KiB) +@test sol_disco_radau.retcode == ReturnCode.Success +sol_no_disco_rosenbrock = solve(prob_multi, Rodas5P(); callback=cb_multi2, reltol=1e-7, abstol=1e-9) +# 253.709 μs (1380 allocations: 74.28 KiB) +@test sol_disco_rosenbrock.retcode == ReturnCode.Success +@test sol_disco_rosenbrock.stats.nreject <= sol_no_disco_rosenbrock.stats.nreject + +sol_disco_tsit5 = solve(prob_multi, Tsit5(); callback=cb_multi, reltol=1e-7, abstol=1e-9, controller = PI_disco_controller(Tsit5())) +# 127.375 μs (1953 allocations: 87.49 KiB) +@test sol_disco_rosenbrock.retcode == ReturnCode.Success +sol_no_disco_tsit5 = solve(prob_multi, Tsit5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) +# 95.250 μs (1499 allocations: 73.62 KiB) +@test sol_disco_tsit5.retcode == ReturnCode.Success +@test sol_disco_tsit5.stats.nreject <= sol_no_disco_tsit5.stats.nreject + +#TEST 4: STIFF DISCONTINUITY +# very stiff discontinuous system +@test sol_disco_tsit5.retcode == ReturnCode.Success +function f_stiff_disc!(du, u, p, t) + λ = p[1] # stiffness parameter + if u[1] < 0.5 + du[1] = -λ * u[1] + λ * exp(-t) # stiff decay with forcing + else + du[1] = u[1] + end +end + +u0_stiff = [0.1] +tspan_stiff = (0.0, 3.0) +prob_stiff = ODEProblem(f_stiff_disc!, u0_stiff, tspan_stiff, [100.0]) + +#define callback +cond_stiff(u, t, integrator) = u[1] - 0.5 +cb_stiff = ContinuousCallback(cond_stiff, default_affect!; maybe_discontinuity = true) +cb_stiff_f = ContinuousCallback(cond_stiff, default_affect!; maybe_discontinuity = false) + +#disco solve +sol_disco_radau = solve(prob_stiff, RadauIIA5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11, controller = predictive_disco_controller(RadauIIA5())) +# 149.167 μs (1819 allocations: 75.19 KiB) +sol_no_disco_radau = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) +# 138.125 μs (1565 allocations: 64.09 KiB) +@test sol_disco_radau.retcode == ReturnCode.Success +@test sol_disco_radau.stats.nreject <= sol_no_disco_radau.stats.nreject + +sol_disco_tsit5 = solve(prob_stiff, Tsit5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11, controller = PI_disco_controller(Tsit5())) +# 93.833 μs (2040 allocations: 80.59 KiB) +sol_no_disco_tsit5 = solve(prob_stiff, Tsit5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) +# 82.750 μs (1898 allocations: 72.12 KiB) +@test sol_disco_tsit5.retcode == ReturnCode.Success +@test sol_disco_tsit5.stats.nreject <= sol_no_disco_tsit5.stats.nreject + +#TEST 5: DISCONTINUOUS DAE +# discontinuous DAE with mass matrix +# System: M * du/dt = f(u, p, t) +# du[1]/dt = u[2] - u[1] +# 0 = u[1] + u[2] - 1 (algebraic constraint) +function f_dae_disc!(du, u, p, t) + if u[1] < 0.5 + du[1] = 2 * u[2] - u[1] + du[2] = u[1] + u[2] - 1 # algebraic constraint + else + du[1] = -u[1] + u[2] + du[2] = u[1] + u[2] - 1 + end +end + +u0_dae = [0.2, 0.8] +tspan_dae = (0.0, 2.0) + +M_dae = [1.0 0.0; 0.0 0.0] + +f_dae_func = ODEFunction(f_dae_disc!; mass_matrix=M_dae) +prob_dae = ODEProblem(f_dae_func, u0_dae, tspan_dae) + +cond_dae(u, t, integrator) = u[1] - 0.5 +cb_dae = ContinuousCallback(cond_dae, default_affect!; maybe_discontinuity = true) +cb_daef = ContinuousCallback(cond_dae, default_affect!; maybe_discontinuity = false) + +radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol=1e-10, controller = predictive_disco_controller(RadauIIA5())) +# 88.542 μs (870 allocations: 41.86 KiB) +radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) +# 73.000 μs (673 allocations: 32.05 KiB) +@test radau_disco.retcode == ReturnCode.Success +@test radau_disco.stats.nreject <= radau_no_disco.stats.nreject + +#TEST 6: VECTOR CALLBACK +function f!(du, u, p, t) + du[1] = -u[1] + du[2] = 0.2*u[1] - 0.1*u[2] +end + +u0 = [3.0, 0.0] +tspan = (0.0, 10.0) +prob = ODEProblem(f!, u0, tspan) + +# u[1] == 2.0 and u[1] == 1.0 +function condition!(out, u, t, integrator) + out[1] = u[1] - 2.0 + out[2] = u[1] - 1.0 +end + +# Discontinuous update to the state when an event fires +function affect!(integrator, idx) + if idx == 1 + # when u[1] crosses 2, kick u[2] up (jump discontinuity) + integrator.u[2] += 5.0 + elseif idx == 2 + # when u[1] crosses 1, reset u[2] + integrator.u[2] = 0.0 + end +end + +cb = VectorContinuousCallback(condition!, affect!, 2; maybe_discontinuity = true) +cb2 = VectorContinuousCallback(condition!, affect!, 2; maybe_discontinuity = false) + +sol_disco_radau = solve(prob, RadauIIA5(); callback = cb, controller = predictive_disco_controller(RadauIIA5())) +# 49.125 μs (664 allocations: 32.89 KiB) +sol_no_disco_radau = solve(prob, RadauIIA5(); callback = cb2) +# 37.375 μs (531 allocations: 25.23 KiB) +@test sol_disco_radau.retcode == ReturnCode.Success +@test sol_disco_radau.stats.nreject <= sol_no_disco_radau.stats.nreject + +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, controller = PI_disco_controller(Rodas5P())) +# 57.333 μs (592 allocations: 31.23 KiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2) +# 44.250 μs (476 allocations: 23.73 KiB) +@test sol_disco_rosenbrock.retcode == ReturnCode.Success +@test sol_disco_rosenbrock.stats.nreject <= sol_no_disco_rosenbrock.stats.nreject + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, controller = PI_disco_controller(Tsit5())) +# 37.833 μs (673 allocations: 31.80 KiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2) +# 24.958 μs (557 allocations: 24.23 KiB) +@test sol_disco_tsit5.retcode == ReturnCode.Success +@test sol_disco_tsit5.stats.nreject <= sol_no_disco_tsit5.stats.nreject + +#TEST 7 +function f!(du, u, p, t) + x1, x2 = u + du[1] = x2 + if x2 < 0.0 + du[2] = -x1 + 1.0 + else + du[2] = -x1 - 1.0 + end +end + +u = [1.5, 0.8] +tspan = (0.0, 2.0) +prob = ODEProblem(f!, u, tspan) + +cond(u, t, integrator) = u[2] +cb = ContinuousCallback(cond, default_affect!; maybe_discontinuity = true) +cb2 = ContinuousCallback(cond, default_affect!; maybe_discontinuity = false) + +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-8, abstol = 1e-10, controller = PI_disco_controller(Rodas5P())) +# 240.291 μs (1821 allocations: 71.56 KiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-8, abstol = 1e-10) +# 184.625 μs (1029 allocations: 49.23 KiB) +@test sol_disco_rosenbrock.retcode == ReturnCode.Success +@test sol_disco_rosenbrock.stats.nreject <= sol_no_disco_rosenbrock.stats.nreject + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10, controller = PI_disco_controller(Tsit5())) +# 79.791 μs (1678 allocations: 73.85 KiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) +# 55.958 μs (1259 allocations: 57.04 KiB) +@test sol_disco_tsit5.retcode == ReturnCode.Success +@test sol_disco_tsit5.stats.nreject <= sol_no_disco_tsit5.stats.nreject \ No newline at end of file