Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/BoundaryValueDiffEqCore/src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ function __maybe_allocate_diffcache(x, chunksize, jac_alg)
return __needs_diffcache(jac_alg) ?
DiffCache(x, chunksize; warn_on_resize = false) : x
end

# Build a new `DiffCache` from an existing one: the backing array is `zero(x.du)`,
# `chunksize` is forwarded unchanged, and `warn_on_resize = false` is passed along.
function __maybe_allocate_diffcache(x::DiffCache, chunksize)
    fresh = zero(x.du)
    return DiffCache(fresh, chunksize; warn_on_resize = false)
end
Expand Down
72 changes: 62 additions & 10 deletions lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Resolve a possibly-cached buffer: a `DiffCache` is unwrapped through
# `PreallocationTools.get_tmp(x, u)`; any other value is handed back unchanged.
@inline function _maybe_get_tmp(x::DiffCache, u)
    return PreallocationTools.get_tmp(x, u)
end
@inline function _maybe_get_tmp(x, u)
    return x
end

# Total number of scalar entries held by a vector of arrays (or of `DiffCache`s).
#
# For plain arrays this is the sum of the element lengths. For `DiffCache`s the length
# of each cache's `du` buffer is summed; `DiffCache` exposes its primal storage as `du`
# and has no `u` field, so reading `.u` would fail at runtime.
recursive_length(x::Vector{<:AbstractArray}) = sum(length, x)
recursive_length(x::Vector{<:DiffCache}) = sum(xᵢ -> length(xᵢ.du), x)

Expand All @@ -15,6 +18,16 @@ end
end
return y
end

# Flatten a vector of `DiffCache`s into `y`: for every cache the buffer returned by
# `PreallocationTools.get_tmp(cache, u)` is copied into the next consecutive slice of
# `y`. Returns `y`.
@views function recursive_flatten!(
    y::AbstractVector, x::AbstractVector{<:DiffCache}, u::AbstractVector
)
    offset = 0
    for cache in x
        chunk = PreallocationTools.get_tmp(cache, u)
        len = length(chunk)
        copyto!(y[(offset + 1):(offset + len)], chunk)
        offset += len
    end
    return y
end
@views function recursive_flatten_twopoint!(y::AbstractVector, x::Vector{<:AbstractArray}, sizes)
x_, xiter = first(x), x[2:end]
copyto!(y[1:prod(sizes[1])], x_[1:prod(sizes[1])])
Expand All @@ -27,6 +40,21 @@ end
return y
end

# Two-point flattening for a vector of `DiffCache`s. Buffers are resolved with
# `PreallocationTools.get_tmp(·, u)` and written into `y` in this order:
#   1. the first `prod(sizes[1])` entries of the first buffer,
#   2. every remaining buffer (`x[2]`, `x[3]`, ...) in full,
#   3. the last `prod(sizes[2])` entries of the first buffer.
# Returns `y`.
@views function recursive_flatten_twopoint!(
    y::AbstractVector, x::AbstractVector{<:DiffCache}, u::AbstractVector, sizes
)
    head = PreallocationTools.get_tmp(first(x), u)
    na = prod(sizes[1])
    copyto!(y[1:na], head[1:na])
    offset = na
    for j in 2:length(x)
        chunk = PreallocationTools.get_tmp(x[j], u)
        len = length(chunk)
        copyto!(y[(offset + 1):(offset + len)], chunk)
        offset += len
    end
    nb = prod(sizes[2])
    copyto!(y[(offset + 1):(offset + nb)], head[(end - nb + 1):end])
    return y
end

@views function recursive_unflatten!(y::Vector{<:AbstractArray}, x::AbstractVector)
i = 0
for yᵢ in y
Expand All @@ -40,6 +68,28 @@ end
return recursive_unflatten!(get_tmp.(y, (x,)), x)
end

# Unflatten `x` into the buffers of `y`, recording them in a pre-allocated `y_cache`
# (intended to be the non-allocating path).
# This method is selected only when `x` and the vectors stored in `y_cache` share the
# element type `T` (primal path): slot `j` of `y_cache` is pointed at the buffer
# `get_tmp(y[j], x)` and that buffer is filled from the next slice of `x`.
# When the element types differ (Dual path), dispatch lands on the more general
# `y_cache::Vector` method, which falls back to a broadcast allocation.
@views function recursive_unflatten!(
    y::Vector{<:DiffCache}, y_cache::Vector{<:AbstractVector{T}}, x::AbstractVector{T}
) where {T}
    offset = 0
    for j in eachindex(y)
        buf = PreallocationTools.get_tmp(y[j], x)
        y_cache[j] = buf
        len = length(buf)
        copyto!(buf, x[(offset + 1):(offset + len)])
        offset += len
    end
    return y_cache
end

# Fallback for argument combinations the typed `y_cache` method does not accept.
# `y_cache` is ignored here: the buffers are gathered with a broadcasted `get_tmp`
# (which builds a new vector) and then filled by the two-argument method.
function recursive_unflatten!(
    y::Vector{<:DiffCache}, y_cache::Vector, x::AbstractVector
)
    buffers = get_tmp.(y, (x,))
    return recursive_unflatten!(buffers, x)
end

@views function recursive_unflatten!(y::AbstractVectorOfArray, x::AbstractVector)
i = 0
for yᵢ in y
Expand All @@ -56,18 +106,20 @@ function diff!(dx, x)
return dx
end

# Compute `z = α * A * b + β * z` in place with plain scalar loops and return `z`.
#
# `α` defaults to one and `β` to zero in the element type of `z`.
# When `β` is zero the previous contents of `z` are overwritten instead of scaled, so
# stale `NaN`/`Inf` entries cannot leak into the result (`NaN * 0` is still `NaN`).
#
# All indexing happens under `@inbounds`: `z` must be indexable over `axes(A, 1)` and
# `b` over `axes(A, 2)`.
function __maybe_matmul!(z, A, b, α = one(eltype(z)), β = zero(eltype(z)))
    # First z = β*z, treating β == 0 as "discard the old z".
    if iszero(β)
        z0 = zero(eltype(z))
        @inbounds for i in eachindex(z)
            z[i] = z0
        end
    else
        @inbounds for i in eachindex(z)
            z[i] *= β
        end
    end

    # Then z += α*A*b, one column of A at a time so the inner loop runs over rows.
    @inbounds for j in axes(A, 2)
        bj = α * b[j]
        for i in axes(A, 1)
            z[i] += A[i, j] * bj
        end
    end

    return z
end

Expand Down
6 changes: 3 additions & 3 deletions lib/BoundaryValueDiffEqFIRK/src/firk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,7 @@ end
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
Φ!(resids[2:end], cache, y_, u, trait, constraint)
eval_sol.u[1:end] .= y_
copyto!(eval_sol.u, y_)
eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh)
recursive_flatten!(resid, resids)
return nothing
Expand All @@ -1359,7 +1359,7 @@ end
y_ = recursive_unflatten!(y, u)
resids = [r for r in residual]
Φ!(resids[2:end], cache, y_, u, trait, constraint)
eval_sol.u[1:end] .= y_
copyto!(eval_sol.u, y_)
eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh)
recursive_flatten!(resid, resids)
return nothing
Expand Down Expand Up @@ -1419,7 +1419,7 @@ end
u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, eval_sol, trait
) where {BC}
y_ = recursive_unflatten!(y, u)
eval_sol.u[1:end] .= y_
copyto!(eval_sol.u, y_)
resid_bc = eval_bc_residual(pt, bc, eval_sol, p, mesh)
resid_co = Φ(cache, y_, u, trait)
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
Expand Down
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm,
DiffCacheNeeded, NoDiffCacheNeeded, __split_kwargs,
__concrete_kwargs, __FastShortcutNonlinearPolyalg,
__construct_internal_problem, __internal_solve,
__default_sparsity_detector, __build_cost, __add_singular_term!
__default_sparsity_detector, __build_cost, __add_singular_term!,
_maybe_get_tmp

using ConcreteStructs: @concrete
using DifferentiationInterface: DifferentiationInterface, Constant, prepare_jacobian
Expand Down
6 changes: 3 additions & 3 deletions lib/BoundaryValueDiffEqMIRK/src/collocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ end
T = eltype(u)
for i in eachindex(k_discrete)
K = get_tmp(k_discrete[i], u)
residᵢ = residual[i]
residᵢ = _maybe_get_tmp(residual[i], u)
h = mesh_dt[i]

yᵢ = get_tmp(y[i], u)
Expand Down Expand Up @@ -51,7 +51,7 @@ end
T = eltype(u)
for i in eachindex(k_discrete)
K = get_tmp(k_discrete[i], u)
residᵢ = residual[i]
residᵢ = _maybe_get_tmp(residual[i], u)
h = mesh_dt[i]

yᵢ = get_tmp(y[i], u)
Expand All @@ -77,7 +77,7 @@ end
)
(; c, v, x, b) = TU

tmp = similar(fᵢ_cache)
tmp = fᵢ_cache
T = eltype(u)
for i in eachindex(k_discrete)
K = k_discrete[i]
Expand Down
6 changes: 3 additions & 3 deletions lib/BoundaryValueDiffEqMIRK/src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,9 @@ always update the intermediate solution with discrete solution + discrete stages
(Continuous MIRK: u(meshᵢ + τ*dt) = yᵢ + dt sum br(τ)*kr).
"""
@views function update_eval_sol!(eval_sol::EvalSol, y_, cache::MIRKCache)
    # Each destination is written exactly once with `copyto!`; the earlier
    # `dest[1:end] .= src` assignments wrote the same data a second time.
    # Discrete solution, restructured with `cache.in_size`.
    copyto!(eval_sol.u, __restructure_sol(y_, cache.in_size))
    # Discrete and interpolation stages taken from the solver cache.
    copyto!(eval_sol.cache.k_discrete, cache.k_discrete)
    copyto!(eval_sol.cache.k_interp.u, cache.k_interp.u)
    # Re-run the interpolation setup on the refreshed evaluation cache.
    interp_setup!(eval_sol.cache)
    return nothing
end
Expand Down
42 changes: 22 additions & 20 deletions lib/BoundaryValueDiffEqMIRK/src/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
y
y₀
residual
# The following 2 caches are never resized
# The following 3 caches are never resized
fᵢ_cache
fᵢ₂_cache
y_cache # Pre-allocated get_tmp pointers for primal path
errors
new_stages
resid_size
Expand Down Expand Up @@ -72,6 +73,7 @@ function SciMLBase.__init(

fᵢ_cache = __alloc(zero(u0))
fᵢ₂_cache = vec(zero(u0))
y_cache = Vector{Vector{T}}(undef, Nig + 1)

# Don't flatten this here, since we need to expand it later if needed
y₀ = __initial_guess_on_mesh(prob.u0, mesh, prob.p; tune_parameters = tune_parameters)
Expand Down Expand Up @@ -237,7 +239,7 @@ function SciMLBase.__init(
return MIRKCache{iip, T, use_both, typeof(diffcache), tune_parameters}(
alg_order(alg), stage, N, size(u0), f, bc, prob_, prob.problem_type, prob.p, alg,
TU, ITU, f_prototype, bcresid_prototype, mesh, mesh_dt, k_discrete, k_interp, y,
y₀, residual, fᵢ_cache, fᵢ₂_cache, errors, new_stages, resid₁_size, prob.singular_term
y₀, residual, fᵢ_cache, fᵢ₂_cache, y_cache, errors, new_stages, resid₁_size, prob.singular_term
, nlsolve_kwargs, optimize_kwargs, (; abstol, dt, adaptive, controller, tune_parameters, kwargs...), verbose_spec
)
end
Expand All @@ -257,6 +259,7 @@ function __expand_cache!(cache::MIRKCache{iip, T, use_both}) where {iip, T, use_
__resize!(cache.residual, Nₙ, cache.M)
__resize!(cache.errors, ifelse(use_both, 2 * (Nₙ - 1), (Nₙ - 1)), cache.M)
__resize!(cache.new_stages, Nₙ - 1, cache.M)
resize!(cache.y_cache, Nₙ)
return cache
end

Expand Down Expand Up @@ -445,12 +448,11 @@ end
resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, mesh,
cache, eval_sol, trait::DiffCacheNeeded, constraint
) where {BC}
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
Φ!(resids[2:end], cache, y_, u, trait, constraint)
y_ = recursive_unflatten!(y, cache.y_cache, u)
Φ!(residual[2:end], cache, y_, u, trait, constraint)
update_eval_sol!(eval_sol, y_, cache)
eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh)
recursive_flatten!(resid, resids)
eval_bc_residual!(get_tmp(residual[1], u), pt, bc!, eval_sol, p, mesh)
recursive_flatten!(resid, residual, u)
return nothing
end

Expand Down Expand Up @@ -483,13 +485,13 @@ end
resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, residual,
mesh, cache, _, trait::DiffCacheNeeded, constraint
) where {BC1, BC2}
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
Φ!(resids[2:end], cache, y_, u, trait, constraint)
resida = resids[1][1:prod(cache.resid_size[1])]
residb = resids[1][(prod(cache.resid_size[1]) + 1):end]
y_ = recursive_unflatten!(y, cache.y_cache, u)
Φ!(residual[2:end], cache, y, u, trait, constraint)
resid0 = get_tmp(residual[1], u)
resida = resid0[1:prod(cache.resid_size[1])]
residb = resid0[(prod(cache.resid_size[1]) + 1):end]
eval_bc_residual!((resida, residb), pt, bc!, y_, p, mesh)
recursive_flatten_twopoint!(resid, resids, cache.resid_size)
recursive_flatten_twopoint!(resid, residual, u, cache.resid_size)
return nothing
end

Expand Down Expand Up @@ -555,20 +557,20 @@ end
# Collocation-only loss for the `DiffCacheNeeded` trait; writes into `resid` and
# returns `nothing`.
#
# Steps:
#   1. unflatten `u` into the buffers of `y`, reusing `cache.y_cache`,
#   2. evaluate `Φ!` on `residual[2:end]` (a view, because of `@views`),
#   3. flatten those residual caches into `resid`, resolving them against `u`.
# The residual is computed once; the older path that built a temporary
# `[get_tmp(r, u) for r in ...]` vector and evaluated `Φ!` a second time is gone.
@views function __mirk_loss_collocation!(
    resid, u, p, y, mesh, residual, cache, trait::DiffCacheNeeded, constraint
)
    recursive_unflatten!(y, cache.y_cache, u)
    collocation_residual = residual[2:end]
    Φ!(collocation_residual, cache, y, u, trait, constraint)
    recursive_flatten!(resid, collocation_residual, u)
    return nothing
end

# Collocation-only loss for the `NoDiffCacheNeeded` trait; writes into `resid` and
# returns `nothing`.
#
# Steps:
#   1. unflatten `u` into `y`, keeping the result as `y_`,
#   2. evaluate `Φ!` on `residual[2:end]` (a view, because of `@views`),
#   3. flatten that slice into `resid`.
# `Φ!` and the flatten run once; the extra pass over a freshly collected
# `[r for r in residual[2:end]]` vector is gone.
@views function __mirk_loss_collocation!(
    resid, u, p, y, mesh, residual, cache, trait::NoDiffCacheNeeded, constraint
)
    y_ = recursive_unflatten!(y, u)
    collocation_residual = residual[2:end]
    Φ!(collocation_residual, cache, y_, u, trait, constraint)
    recursive_flatten!(resid, collocation_residual)
    return nothing
end

Expand Down
16 changes: 8 additions & 8 deletions lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,10 @@ end
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
Φ!(resids[3:end], cache, y_, u, p)
EvalSol.u[1:end] .= __restructure_sol(y_[1:length(cache.mesh)], cache.in_size)
EvalSol.cache.k_discrete[1:end] .= cache.k_discrete
EvalDSol.u[1:end] .= __restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size)
EvalDSol.cache.k_discrete[1:end] .= cache.k_discrete
copyto!(EvalSol.u, __restructure_sol(y_[1:length(cache.mesh)], cache.in_size))
copyto!(EvalSol.cache.k_discrete, cache.k_discrete)
copyto!(EvalDSol.u, __restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size))
copyto!(EvalDSol.cache.k_discrete, cache.k_discrete)
eval_bc_residual!(resids[1:2], pt, bc, EvalSol, EvalDSol, p, mesh)
recursive_flatten!(resid, resids)
return nothing
Expand All @@ -400,10 +400,10 @@ end
) where {BC}
y_ = recursive_unflatten!(y, u)
resid_co = Φ(cache, y_, u, p)
EvalSol.u[1:end] .= __restructure_sol(y_[1:length(cache.mesh)], cache.in_size)
EvalSol.cache.k_discrete[1:end] .= cache.k_discrete
EvalDSol.u[1:end] .= __restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size)
EvalDSol.cache.k_discrete[1:end] .= cache.k_discrete
copyto!(EvalSol.u, __restructure_sol(y_[1:length(cache.mesh)], cache.in_size))
copyto!(EvalSol.cache.k_discrete, cache.k_discrete)
copyto!(EvalDSol.u, __restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size))
copyto!(EvalDSol.cache.k_discrete, cache.k_discrete)
resid_bc = eval_bc_residual(pt, bc, EvalSol, EvalDSol, p, mesh)
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
end
Expand Down
Loading
Loading