Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@ repo = "https://github.com/SciML/SimpleDiffEq.jl.git"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
DiffEqBase = "6.194.0, 7"
ExplicitImports = "1.14.0"
DiffEqBase = "7"
FastBroadcast = "0.3, 0.4, 1, 2"
JLArrays = "0.1, 0.2, 0.3"
LinearAlgebra = "1"
MuladdMacro = "0.2.4"
OrdinaryDiffEq = "6.106.0, 7"
OrdinaryDiffEqLowOrderRK = "1, 2"
Expand All @@ -25,13 +28,13 @@ Parameters = "0.12, 0.13"
RecursiveArrayTools = "3.37.0, 4"
Reexport = "1.2.2"
SafeTestsets = "0.1"
SciMLTesting = "1"
SciMLBase = "3.27"
SciMLTesting = "1.6"
StaticArrays = "1.9.14"
Test = "1"
julia = "1.10"

[extras]
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqLowOrderRK = "1344f307-1e59-4825-a18e-ace9aa3fa4c6"
Expand All @@ -42,4 +45,4 @@ SciMLTesting = "09d9d899-5365-40a9-917a-5f67fddea283"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ExplicitImports", "JLArrays", "OrdinaryDiffEq", "OrdinaryDiffEqLowOrderRK", "OrdinaryDiffEqTsit5", "OrdinaryDiffEqVerner", "SafeTestsets", "SciMLTesting", "Test"]
test = ["JLArrays", "OrdinaryDiffEq", "OrdinaryDiffEqLowOrderRK", "OrdinaryDiffEqTsit5", "OrdinaryDiffEqVerner", "SafeTestsets", "SciMLTesting", "Test"]
11 changes: 6 additions & 5 deletions src/SimpleDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ module SimpleDiffEq

using Reexport: @reexport
using MuladdMacro: @muladd
@reexport using FastBroadcast: @..
@reexport using DiffEqBase: DiffEqBase, ODEProblem, SDEProblem, DiscreteProblem,
isinplace, reinit!, ODE_DEFAULT_NORM,
set_t!, solve, step!, init, @..,
set_t!, solve, step!, init, isdiscrete
@reexport using SciMLBase: SciMLBase, build_solution, is_diagonal_noise,
AbstractSDEAlgorithm, AbstractODEAlgorithm,
AbstractODEIntegrator, DEIntegrator, ConstantInterpolation,
__init, __solve, build_solution, has_analytic,
calculate_solution_errors!, is_diagonal_noise,
AbstractSDEAlgorithm, AbstractODEAlgorithm, isdiscrete, SciMLBase
import DiffEqBase.SciMLBase: allows_arbitrary_number_types, allowscomplex, isautodifferentiable, isadaptive
__init, __solve, has_analytic, calculate_solution_errors!
import SciMLBase: allows_arbitrary_number_types, allowscomplex, isautodifferentiable, isadaptive
# `derivative_discontinuity!` was introduced in DiffEqBase v7 / SciMLBase v3,
# replacing the older `u_modified!`. Support both branches so the package can
# be used with either DiffEqBase v6 or v7.
Expand Down
12 changes: 6 additions & 6 deletions src/euler/euler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct SimpleEuler <: AbstractSimpleDiffEqODEAlgorithm end
export SimpleEuler

mutable struct SimpleEulerIntegrator{IIP, S, T, P, F} <:
DiffEqBase.AbstractODEIntegrator{SimpleEuler, IIP, S, T}
SciMLBase.AbstractODEIntegrator{SimpleEuler, IIP, S, T}
f::F # ..................................... Equations of motion
uprev::S # .......................................... Previous state
u::S # ........................................... Current state
Expand All @@ -71,7 +71,7 @@ DiffEqBase.isinplace(::SEI{IIP}) where {IIP} = IIP
# Initialization
################################################################################

function DiffEqBase.__init(
function SciMLBase.__init(
prob::ODEProblem, alg::SimpleEuler;
dt = error("dt is required for this algorithm"), kwargs...
)
Expand All @@ -85,7 +85,7 @@ function DiffEqBase.__init(
)
end

function DiffEqBase.__solve(
function SciMLBase.__solve(
prob::ODEProblem, alg::SimpleEuler;
dt = error("dt is required for this algorithm"), kwargs...
)
Expand All @@ -107,10 +107,10 @@ function DiffEqBase.__solve(
us[i + 1] = _copy(integ.u)
end

sol = DiffEqBase.build_solution(prob, alg, ts, us, calculate_error = false)
sol = SciMLBase.build_solution(prob, alg, ts, us, calculate_error = false)

DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(
sol;
timeseries_errors = true,
dense_errors = false
Expand Down
6 changes: 3 additions & 3 deletions src/euler/gpueuler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ export GPUSimpleEuler
us[i] = u
end

sol = DiffEqBase.build_solution(
sol = SciMLBase.build_solution(
prob, alg, ts, SArray(us),
k = nothing, stats = nothing,
calculate_error = false
)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(
sol; timeseries_errors = true,
dense_errors = false
)
Expand Down
18 changes: 9 additions & 9 deletions src/euler/loopeuler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ export LoopEuler

# Out-of-place
# No caching, good for static arrays, bad for arrays
@muladd function DiffEqBase.__solve(
@muladd function SciMLBase.__solve(
prob::ODEProblem{uType, tType, false},
alg::LoopEuler;
dt = error("dt is required for this algorithm"),
Expand Down Expand Up @@ -92,13 +92,13 @@ export LoopEuler

!save_everystep && save_end && (us[end] = u)

sol = DiffEqBase.build_solution(
sol = SciMLBase.build_solution(
prob, alg, ts, us,
k = nothing, stats = nothing,
calculate_error = false
)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(
sol; timeseries_errors = true,
dense_errors = false
)
Expand All @@ -107,7 +107,7 @@ end

# In-place
# Good for mutable objects like arrays
# Use DiffEqBase.@.. for simd ivdep
# Use @.. for simd ivdep
@muladd function DiffEqBase.solve(
prob::ODEProblem{uType, tType, true},
alg::LoopEuler;
Expand Down Expand Up @@ -147,19 +147,19 @@ end
for i in 2:length(ts)
t = ts[i]
f(k, u, p, t)
DiffEqBase.@.. u = u + dt * k
@.. u = u + dt * k
save_everystep && (us[i] = copy(u))
end

!save_everystep && save_end && (us[end] = u)

sol = DiffEqBase.build_solution(
sol = SciMLBase.build_solution(
prob, alg, ts, us,
k = nothing, stats = nothing,
calculate_error = false
)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(
sol; timeseries_errors = true,
dense_errors = false
)
Expand Down
18 changes: 9 additions & 9 deletions src/euler_maruyama.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ sol = solve(prob, SimpleEM(), dt = 0.01)

- [`SDEProblem`](@ref) for problem setup
"""
struct SimpleEM <: DiffEqBase.AbstractSDEAlgorithm end
struct SimpleEM <: SciMLBase.AbstractSDEAlgorithm end
export SimpleEM

@muladd function DiffEqBase.solve(
Expand Down Expand Up @@ -76,7 +76,7 @@ export SimpleEM
end
end

sol = DiffEqBase.build_solution(
sol = SciMLBase.build_solution(
prob, alg, t, u,
calculate_error = false
)
Expand All @@ -97,9 +97,9 @@ end
tspan = prob.tspan
p = prob.p
ftmp = zero(u0)
gtmp = DiffEqBase.is_diagonal_noise(prob) ? zero(u0) : zero(prob.noise_rate_prototype)
gtmp2 = DiffEqBase.is_diagonal_noise(prob) ? nothing : zero(u0)
dW = DiffEqBase.is_diagonal_noise(prob) ? zero(u0) :
gtmp = SciMLBase.is_diagonal_noise(prob) ? zero(u0) : zero(prob.noise_rate_prototype)
gtmp2 = SciMLBase.is_diagonal_noise(prob) ? nothing : zero(u0)
dW = SciMLBase.is_diagonal_noise(prob) ? zero(u0) :
false .* prob.noise_rate_prototype[1, :]

@inbounds begin
Expand All @@ -116,15 +116,15 @@ end
g(gtmp, uprev, p, tprev)
@. dW = randn(eltype(dW))

if DiffEqBase.is_diagonal_noise(prob)
DiffEqBase.@.. u[i] = uprev + ftmp * dt + sqdt * gtmp * dW
if SciMLBase.is_diagonal_noise(prob)
@.. u[i] = uprev + ftmp * dt + sqdt * gtmp * dW
else
mul!(gtmp2, gtmp, dW)
DiffEqBase.@.. u[i] = uprev + ftmp * dt + sqdt * gtmp2
@.. u[i] = uprev + ftmp * dt + sqdt * gtmp2
end
end

sol = DiffEqBase.build_solution(
sol = SciMLBase.build_solution(
prob, alg, t, u,
calculate_error = false
)
Expand Down
18 changes: 9 additions & 9 deletions src/functionmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export SimpleFunctionMap
SciMLBase.isdiscrete(alg::SimpleFunctionMap) = true

# ConstantCache version
function DiffEqBase.__solve(
function SciMLBase.__solve(
prob::DiffEqBase.DiscreteProblem{uType, tupType, false},
alg::SimpleFunctionMap;
calculate_values = true, kwargs...
Expand All @@ -57,15 +57,15 @@ function DiffEqBase.__solve(
u[i] = f(u[i - 1], p, t[i])
end
end
return sol = DiffEqBase.build_solution(
return sol = SciMLBase.build_solution(
prob, alg, t, u, dense = false,
interp = DiffEqBase.ConstantInterpolation(t, u),
interp = SciMLBase.ConstantInterpolation(t, u),
calculate_error = false
)
end

# Cache version
function DiffEqBase.__solve(
function SciMLBase.__solve(
prob::DiscreteProblem{uType, tupType, true},
alg::SimpleFunctionMap;
calculate_values = true, kwargs...
Expand All @@ -86,9 +86,9 @@ function DiffEqBase.__solve(
f(u[i], u[i - 1], p, t[i])
end
end
return sol = DiffEqBase.build_solution(
return sol = SciMLBase.build_solution(
prob, alg, t, u, dense = false,
interp = DiffEqBase.ConstantInterpolation(t, u),
interp = SciMLBase.ConstantInterpolation(t, u),
calculate_error = false
)
end
Expand All @@ -97,7 +97,7 @@ end

# Integrator version
mutable struct DiscreteIntegrator{F, IIP, uType, tType, P, S} <:
DiffEqBase.DEIntegrator{SimpleFunctionMap, IIP, uType, tType}
SciMLBase.DEIntegrator{SimpleFunctionMap, IIP, uType, tType}
f::F
u::uType
t::tType
Expand All @@ -108,12 +108,12 @@ mutable struct DiscreteIntegrator{F, IIP, uType, tType, P, S} <:
tdir::tType
end

function DiffEqBase.__init(
function SciMLBase.__init(
prob::DiscreteProblem,
alg::SimpleFunctionMap;
kwargs...
)
sol = DiffEqBase.__solve(prob, alg; calculate_values = false)
sol = SciMLBase.__solve(prob, alg; calculate_values = false)
F = typeof(prob.f)
IIP = isinplace(prob)
uType = typeof(prob.u0)
Expand Down
6 changes: 3 additions & 3 deletions src/rk4/gpurk4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ export GPUSimpleRK4
us[i] = u
end

sol = DiffEqBase.build_solution(
sol = SciMLBase.build_solution(
prob, alg, ts, SArray(us),
k = nothing, stats = nothing,
calculate_error = false
)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(
sol; timeseries_errors = true,
dense_errors = false
)
Expand Down
24 changes: 12 additions & 12 deletions src/rk4/looprk4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ export LoopRK4

# Out-of-place
# No caching, good for static arrays, bad for arrays
@muladd function DiffEqBase.__solve(
@muladd function SciMLBase.__solve(
prob::ODEProblem{uType, tType, false},
alg::LoopRK4;
dt = error("dt is required for this algorithm"),
Expand Down Expand Up @@ -100,13 +100,13 @@ export LoopRK4

!save_everystep && save_end && (us[end] = u)

sol = DiffEqBase.build_solution(
sol = SciMLBase.build_solution(
prob, alg, ts, us,
k = nothing, stats = nothing,
calculate_error = false
)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(
sol; timeseries_errors = true,
dense_errors = false
)
Expand All @@ -115,7 +115,7 @@ end

# In-place
# Good for mutable objects like arrays
# Use DiffEqBase.@.. for simd ivdep
# Use @.. for simd ivdep
@muladd function DiffEqBase.solve(
prob::ODEProblem{uType, tType, true},
alg::LoopRK4;
Expand Down Expand Up @@ -162,25 +162,25 @@ end
uprev .= u
t = ts[i]
f(k1, u, p, t)
DiffEqBase.@.. u = uprev + dt * half * k1
@.. u = uprev + dt * half * k1
f(k2, u, p, t + half * dt)
DiffEqBase.@.. u = uprev + dt * half * k2
@.. u = uprev + dt * half * k2
f(k3, u, p, t + half * dt)
DiffEqBase.@.. u = uprev + dt * k3
@.. u = uprev + dt * k3
f(k4, u, p, t + dt)
DiffEqBase.@.. u = uprev + dt * sixth * (k1 + 2k2 + 2k3 + k4)
@.. u = uprev + dt * sixth * (k1 + 2k2 + 2k3 + k4)
save_everystep && (us[i] = copy(u))
end

!save_everystep && save_end && (us[end] = u)

sol = DiffEqBase.build_solution(
sol = SciMLBase.build_solution(
prob, alg, ts, us,
k = nothing, stats = nothing,
calculate_error = false
)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(
sol; timeseries_errors = true,
dense_errors = false
)
Expand Down
Loading
Loading