Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "3.15.1"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -23,7 +24,9 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
Expand All @@ -47,27 +50,35 @@ AMDGPU = "2"
Adapt = "4"
CUDA = "5, 6"
ChainRulesCore = "1"
CommonSolve = "0.2"
DiffEqBase = "7"
Distributed = "1"
DocStringExtensions = "0.9"
ForwardDiff = "1"
GPUArraysCore = "0.2"
JLArrays = "0.3"
KernelAbstractions = "0.9.38"
LinearAlgebra = "1"
LinearSolve = "3"
Metal = "1"
MuladdMacro = "0.2.4"
OpenCL = "0.10"
Parameters = "0.12, 0.13"
Random = "1"
RecursiveArrayTools = "4"
SciMLBase = "3.1"
Setfield = "1"
SimpleDiffEq = "1.11"
SimpleNonlinearSolve = "2"
StaticArrays = "1.9.14"
StaticArraysCore = "1.4"
TOML = "1"
Test = "1"
UnPack = "1"
ZygoteRules = "0.2.7"
julia = "1.10"
oneAPI = "2"
pocl_jll = "7"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Expand Down
54 changes: 37 additions & 17 deletions src/DiffEqGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,58 @@ $(DocStringExtensions.README)
"""
module DiffEqGPU

using DocStringExtensions
using KernelAbstractions
using DocStringExtensions: DocStringExtensions
using KernelAbstractions: KernelAbstractions, @Const, @index, @kernel, CPU
import KernelAbstractions: get_backend, allocate
using SciMLBase, DiffEqBase, LinearAlgebra, Distributed
using ForwardDiff
using SciMLBase: SciMLBase, CallbackSet, CheckInit, ContinuousCallback,
DiscreteCallback, EnsembleDistributed, EnsembleProblem,
EnsembleSerial, EnsembleSolution, EnsembleThreads, ODEFunction,
ODEProblem, ReturnCode, SDEFunction, SDEProblem,
VectorContinuousCallback, remake, terminate!
using DiffEqBase: DiffEqBase, BrownFullBasicInit
using LinearAlgebra: LinearAlgebra, I, LowerTriangular, NoPivot, RowMaximum,
SingularException, UpperTriangular, det
using Distributed: Distributed, nprocs, pmap
using ForwardDiff: ForwardDiff
import ChainRulesCore
import ChainRulesCore: NoTangent
using RecursiveArrayTools
using RecursiveArrayTools: RecursiveArrayTools, VectorOfArray
import ZygoteRules
import Base.Threads
using LinearSolve
using SimpleNonlinearSolve
using Base: setindex
using CommonSolve: solve
using LinearSolve: LinearSolve
using SimpleNonlinearSolve: SimpleNonlinearSolve
import SimpleNonlinearSolve: SimpleTrustRegion
#For gpu_tsit5
using Adapt, SimpleDiffEq, StaticArrays
using Parameters, MuladdMacro
using Random
using Setfield
using ForwardDiff
import StaticArrays: StaticVecOrMat, @_inline_meta
# import LinearAlgebra: \
import StaticArrays: LU, StaticLUMatrix, arithmetic_closure
using Adapt: Adapt, adapt
using SimpleDiffEq: SimpleDiffEq, GPUSimpleATsit5, GPUSimpleAVern7, GPUSimpleAVern9,
GPUSimpleTsit5, GPUSimpleVern7, GPUSimpleVern9, SimpleEM
using StaticArrays: StaticArrays
using StaticArraysCore: MArray, MMatrix, SArray, SMatrix, SVector, Size,
StaticMatrix, StaticVector, similar_type
using Parameters: Parameters
using MuladdMacro: MuladdMacro, @muladd
using Random: Random
using Setfield: Setfield, @set, @set!
using UnPack: @unpack
# StaticArraysCore-owned type alias (re-exported by StaticArrays); used in dispatch.
import StaticArrays: StaticVecOrMat
# Non-public StaticArrays internals used by the vendored GPU LU/linsolve kernels.
import StaticArrays: @_inline_meta, LU, StaticLUMatrix
import SciMLBase: ImmutableODEProblem

abstract type EnsembleArrayAlgorithm <: SciMLBase.EnsembleAlgorithm end
abstract type EnsembleKernelAlgorithm <: SciMLBase.EnsembleAlgorithm end

##Solvers for EnsembleGPUKernel
abstract type GPUODEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
abstract type GPUSDEAlgorithm <: DiffEqBase.AbstractSDEAlgorithm end
abstract type GPUODEAlgorithm <: SciMLBase.AbstractODEAlgorithm end
abstract type GPUSDEAlgorithm <: SciMLBase.AbstractSDEAlgorithm end
abstract type GPUODEImplicitAlgorithm{AD} <: GPUODEAlgorithm end

_unwrap_val(B) = B
_unwrap_val(::Val{B}) where {B} = B

include("ensemblegpuarray/callbacks.jl")
include("ensemblegpuarray/kernels.jl")
include("ensemblegpuarray/problem_generation.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/ensemblegpukernel/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ alg_order(alg::GPUEM) = 1
alg_order(alg::GPUSIEA) = 2

function finite_diff_jac(f, jac_prototype, x)
dx = sqrt(eps(DiffEqBase.RecursiveArrayTools.recursive_bottom_eltype(x)))
dx = sqrt(eps(RecursiveArrayTools.recursive_bottom_eltype(x)))
jac = MMatrix{size(x, 1), size(x, 1), eltype(x)}(1I)
for i in eachindex(x)
x_dx = convert(MArray, x)
Expand Down
2 changes: 1 addition & 1 deletion src/ensemblegpukernel/gpukernel_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ struct GPUKvaerno5{AD} <: GPUODEImplicitAlgorithm{AD} end
for Alg in [:GPURosenbrock23, :GPURodas4, :GPURodas5P, :GPUKvaerno3, :GPUKvaerno5]
@eval begin
function $Alg(; autodiff = Val{true}())
return $Alg{SciMLBase._unwrap_val(autodiff)}()
return $Alg{_unwrap_val(autodiff)}()
end
end
end
Expand Down
28 changes: 14 additions & 14 deletions src/ensemblegpukernel/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function build_adaptive_controller_cache(alg::A, ::Type{T}) where {A, T}
end

@inline function savevalues!(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP,
S,
Expand Down Expand Up @@ -50,7 +50,7 @@ end
end

@inline function DiffEqBase.terminate!(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP, S,
T,
Expand All @@ -67,7 +67,7 @@ end
end

@inline function apply_discrete_callback!(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP,
S, T,
Expand All @@ -91,7 +91,7 @@ end
end

@inline function apply_discrete_callback!(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP,
S, T,
Expand All @@ -111,7 +111,7 @@ end
end

@inline function apply_discrete_callback!(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP,
S, T,
Expand All @@ -137,7 +137,7 @@ end
end

@inline function apply_discrete_callback!(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP,
S, T,
Expand All @@ -156,7 +156,7 @@ end
end

@inline function interpolate(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP,
S,
Expand All @@ -176,7 +176,7 @@ end
end

@inline function _change_t_via_interpolation!(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP,
S,
Expand Down Expand Up @@ -205,7 +205,7 @@ end
end
end
@inline function DiffEqBase.change_t_via_interpolation!(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP,
S,
Expand All @@ -227,7 +227,7 @@ end
end

@inline function apply_callback!(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType, IIP,
S, T,
},
Expand Down Expand Up @@ -272,7 +272,7 @@ end
end

@inline function handle_callbacks!(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP, S, T,
},
Expand Down Expand Up @@ -384,7 +384,7 @@ end
end

@inline function DiffEqBase.find_callback_time(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP,
S,
Expand Down Expand Up @@ -442,7 +442,7 @@ end
end

@inline function SciMLBase.get_tmp_cache(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP,
S, T,
Expand All @@ -458,7 +458,7 @@ end
end

@inline function DiffEqBase.get_condition(
integrator::DiffEqBase.AbstractODEIntegrator{
integrator::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP,
S, T,
Expand Down
2 changes: 1 addition & 1 deletion src/ensemblegpukernel/integrators/nonstiff/interpolants.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Default: Hermite Interpolation
@inline @muladd function _ode_interpolant(
Θ, dt, y₀,
integ::DiffEqBase.AbstractODEIntegrator{
integ::SciMLBase.AbstractODEIntegrator{
AlgType,
IIP, S, T,
}
Expand Down
12 changes: 6 additions & 6 deletions src/ensemblegpukernel/integrators/nonstiff/types.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## Fixed TimeStep Integrator

mutable struct GPUTsit5Integrator{IIP, S, T, ST, P, F, TS, CB, AlgType} <:
DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T}
SciMLBase.AbstractODEIntegrator{AlgType, IIP, S, T}
alg::AlgType
f::F # eom
uprev::S # previous state
Expand Down Expand Up @@ -52,7 +52,7 @@ DiffEqBase.isinplace(::GPUT5I{IIP}) where {IIP} = IIP
## Adaptive TimeStep Integrator

mutable struct GPUATsit5Integrator{IIP, S, T, ST, P, F, N, TOL, Q, TS, CB, AlgType} <:
DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T}
SciMLBase.AbstractODEIntegrator{AlgType, IIP, S, T}
alg::AlgType
f::F # eom
uprev::S # previous state
Expand Down Expand Up @@ -108,7 +108,7 @@ end
## Vern7

mutable struct GPUV7Integrator{IIP, S, T, ST, P, F, TS, CB, TabType, AlgType} <:
DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T}
SciMLBase.AbstractODEIntegrator{AlgType, IIP, S, T}
alg::AlgType
f::F # eom
uprev::S # previous state
Expand Down Expand Up @@ -152,7 +152,7 @@ const GPUV7I = GPUV7Integrator
end

mutable struct GPUAV7Integrator{IIP, S, T, ST, P, F, N, TOL, Q, TS, CB, TabType, AlgType} <:
DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T}
SciMLBase.AbstractODEIntegrator{AlgType, IIP, S, T}
alg::AlgType
f::F # eom
uprev::S # previous state
Expand Down Expand Up @@ -205,7 +205,7 @@ end
## Vern9

mutable struct GPUV9Integrator{IIP, S, T, ST, P, F, TS, CB, TabType, AlgType} <:
DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T}
SciMLBase.AbstractODEIntegrator{AlgType, IIP, S, T}
alg::AlgType
f::F # eom
uprev::S # previous state
Expand Down Expand Up @@ -249,7 +249,7 @@ const GPUV9I = GPUV9Integrator
end

mutable struct GPUAV9Integrator{IIP, S, T, ST, P, F, N, TOL, Q, TS, CB, TabType, AlgType} <:
DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T}
SciMLBase.AbstractODEIntegrator{AlgType, IIP, S, T}
alg::AlgType
f::F # eom
uprev::S # previous state
Expand Down
Loading
Loading