Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions lib/BoundaryValueDiffEqCore/src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,11 @@ function __maybe_allocate_diffcache(x, chunksize, jac_alg)
end
# Build a new `DiffCache` from a zeroed array shaped like `x.du`, using `chunksize`.
function __maybe_allocate_diffcache(x::DiffCache, chunksize)
    zeroed = zero(x.du)
    return DiffCache(zeroed, chunksize)
end

## `PreallocationTools.get_tmp` emits a warning on cache expansion, as it should. That
## behavior is expected for adaptive BVP solvers, so we write our own `get_tmp` that drops
## the warning logs.

@inline function get_tmp(dc, u)
    # Perform the lookup under a `NullLogger` so that anything it logs is discarded.
    silent = Logging.NullLogger()
    return Logging.with_logger(() -> PreallocationTools.get_tmp(dc, u), silent)
end
## `PreallocationTools.get_tmp` may warn when it expands (resizes) a cache, which is
## expected behavior for adaptive BVP solvers. It is called directly here for performance;
## warnings during adaptive cache expansion are suppressed at the `__expand_cache!` call
## site.
@inline function get_tmp(dc::DiffCache, u)
    return PreallocationTools.get_tmp(dc, u)
end

## Anything that is not a `DiffCache` is handed back unchanged.
@inline function get_tmp(dc, u)
    return dc
end

# DiffCache
struct DiffCacheNeeded end
Expand Down
80 changes: 80 additions & 0 deletions lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Resolve `x` against `u`: a `DiffCache` is unwrapped through
# `PreallocationTools.get_tmp`, while any other value is returned as-is.
@inline function _maybe_get_tmp(x::DiffCache, u)
    return PreallocationTools.get_tmp(x, u)
end

@inline function _maybe_get_tmp(x, u)
    return x
end

# Total number of elements across all arrays stored in `x`.
function recursive_length(x::Vector{<:AbstractArray})
    return sum(length, x)
end
# Total number of elements across the `u` buffers of all `DiffCache`s stored in `x`.
function recursive_length(x::Vector{<:DiffCache})
    return sum(cache -> length(cache.u), x)
end

Expand All @@ -15,6 +18,16 @@ end
end
return y
end

# Copy the buffers obtained from each `DiffCache` in `x` (via
# `PreallocationTools.get_tmp(·, u)`) back-to-back into `y`, and return `y`.
@views function recursive_flatten!(
        y::AbstractVector, x::AbstractVector{<:DiffCache}, u::AbstractVector
    )
    offset = 0
    for cache in x
        chunk = PreallocationTools.get_tmp(cache, u)
        n = length(chunk)
        # `@views` makes the left-hand slice a view, so this writes straight into `y`.
        copyto!(y[(offset + 1):(offset + n)], chunk)
        offset += n
    end
    return y
end
@views function recursive_flatten_twopoint!(y::AbstractVector, x::Vector{<:AbstractArray}, sizes)
x_, xiter = first(x), x[2:end]
copyto!(y[1:prod(sizes[1])], x_[1:prod(sizes[1])])
Expand All @@ -27,6 +40,21 @@ end
return y
end

# Flatten a vector of `DiffCache`s into `y` in this order:
#   1. the first `prod(sizes[1])` entries of the first buffer,
#   2. every remaining buffer (`x[2]` onwards) in full,
#   3. the last `prod(sizes[2])` entries of the first buffer.
# Each buffer is obtained with `PreallocationTools.get_tmp(·, u)`. Returns `y`.
@views function recursive_flatten_twopoint!(
        y::AbstractVector, x::AbstractVector{<:DiffCache}, u::AbstractVector, sizes
    )
    boundary = PreallocationTools.get_tmp(first(x), u)
    nleft = prod(sizes[1])
    copyto!(y[1:nleft], boundary[1:nleft])
    offset = nleft
    for j in 2:length(x)
        chunk = PreallocationTools.get_tmp(x[j], u)
        n = length(chunk)
        copyto!(y[(offset + 1):(offset + n)], chunk)
        offset += n
    end
    nright = prod(sizes[2])
    copyto!(y[(offset + 1):(offset + nright)], boundary[(end - nright + 1):end])
    return y
end

@views function recursive_unflatten!(y::Vector{<:AbstractArray}, x::AbstractVector)
i = 0
for yᵢ in y
Expand All @@ -40,6 +68,28 @@ end
return recursive_unflatten!(get_tmp.(y, (x,)), x)
end

# Non-allocating unflatten that reuses a pre-allocated output container `y_cache`.
# This method is selected when the element type of `x` matches that of the vectors in
# `y_cache` (primal path): every slot of `y_cache` is rebound to the buffer returned by
# `PreallocationTools.get_tmp` and that buffer is filled from the matching slice of `x`.
# When the element types differ (Dual path), dispatch falls back to the more general
# method, which allocates through a broadcast.
@views function recursive_unflatten!(
        y::Vector{<:DiffCache}, y_cache::Vector{<:AbstractVector{T}}, x::AbstractVector{T}
    ) where {T}
    offset = 0
    for (slot, cache) in enumerate(y)
        buf = PreallocationTools.get_tmp(cache, x)
        y_cache[slot] = buf
        n = length(buf)
        copyto!(buf, x[(offset + 1):(offset + n)])
        offset += n
    end
    return y_cache
end

# Fallback for any other `y_cache` / `x` combination: `y_cache` is ignored, the buffers
# are resolved with a broadcast over `get_tmp`, and the result is unflattened from `x`.
@views function recursive_unflatten!(
        y::Vector{<:DiffCache}, y_cache::Vector, x::AbstractVector
    )
    buffers = get_tmp.(y, Ref(x))
    return recursive_unflatten!(buffers, x)
end

@views function recursive_unflatten!(y::AbstractVectorOfArray, x::AbstractVector)
i = 0
for yᵢ in y
Expand Down Expand Up @@ -71,6 +121,21 @@ end
return z
end

"""
    __maybe_matmul!(z, A::AbstractVector{<:AbstractVector}, b, α = 1, β = 0)

In-place update `z = α * (Σⱼ A[j] * b[j]) + β * z`, where `A` is a vector whose entries
act as the columns of a matrix and `b` supplies one coefficient per column.

A zero `β` (including `false`) means "overwrite": `z` is reset to zero instead of being
multiplied by `β`, so stale `NaN`/`Inf` entries in `z` cannot leak into the result. This
mirrors the convention of `LinearAlgebra.mul!`.

Returns `z`.
"""
function __maybe_matmul!(
        z::AbstractArray, A::AbstractVector{<:AbstractVector}, b,
        α = eltype(z)(1), β = eltype(z)(0)
    )
    if iszero(β)
        # `NaN * 0 == NaN`, so scaling would keep garbage from an uninitialized buffer.
        fill!(z, zero(eltype(z)))
    else
        @inbounds for i in eachindex(z)
            z[i] *= β
        end
    end
    @inbounds for j in eachindex(b)
        bj = α * b[j]
        Aj = A[j]
        @simd ivdep for i in eachindex(z)
            z[i] += Aj[i] * bj
        end
    end
    return z
end

"""
interval(mesh, t)

Expand Down Expand Up @@ -261,6 +326,21 @@ function __resize!(x::AbstractVector{<:DiffCache}, n, M)
return x
end

# Resize a vector of `DiffCache` vectors to length `n`.
# Growing appends fresh inner vectors built with `__maybe_allocate_diffcache` from the
# entries of the current last element, with a chunk size picked from `M` times the new
# length. Shrinking truncates with `resize!`. Returns `x`.
function __resize!(x::AbstractVector{<:AbstractVector{<:DiffCache}}, n, M)
    delta = n - length(x)
    if delta > 0
        chunksize = pickchunksize(M * (delta + length(x)))
        template = last(x)
        extra = [
            [__maybe_allocate_diffcache(dc, chunksize) for dc in template]
                for _ in 1:delta
        ]
        append!(x, extra)
    elseif delta < 0
        resize!(x, n)
    end
    return x
end

function __resize!(x::AbstractVectorOfArray, n, M)
N = n - length(x)
N == 0 && return x
Expand Down
6 changes: 3 additions & 3 deletions lib/BoundaryValueDiffEqFIRK/src/firk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1337,7 +1337,7 @@ end
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
Φ!(resids[2:end], cache, y_, u, trait, constraint)
eval_sol.u[1:end] .= y_
copyto!(eval_sol.u, y_)
eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh)
recursive_flatten!(resid, resids)
return nothing
Expand All @@ -1350,7 +1350,7 @@ end
y_ = recursive_unflatten!(y, u)
resids = [r for r in residual]
Φ!(resids[2:end], cache, y_, u, trait, constraint)
eval_sol.u[1:end] .= y_
copyto!(eval_sol.u, y_)
eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh)
recursive_flatten!(resid, resids)
return nothing
Expand Down Expand Up @@ -1410,7 +1410,7 @@ end
u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, eval_sol, trait
) where {BC}
y_ = recursive_unflatten!(y, u)
eval_sol.u[1:end] .= y_
copyto!(eval_sol.u, y_)
resid_bc = eval_bc_residual(pt, bc, eval_sol, p, mesh)
resid_co = Φ(cache, y_, u, trait)
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
Expand Down
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm,
DiffCacheNeeded, NoDiffCacheNeeded, __split_kwargs,
__concrete_kwargs, __FastShortcutNonlinearPolyalg,
__construct_internal_problem, __internal_solve,
__default_sparsity_detector, __build_cost, __add_singular_term!
__default_sparsity_detector, __build_cost, __add_singular_term!,
_maybe_get_tmp

using ConcreteStructs: @concrete
using DifferentiationInterface: DifferentiationInterface, Constant, prepare_jacobian
Expand Down
17 changes: 9 additions & 8 deletions lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ Here, the ki_interp is the stages in one subinterval.
idx₁ = ((1:stage) .- 1) .* (s_star - stage) .+ r
idx₂ = ((1:(r - 1)) .+ stage .- 1) .* (s_star - stage) .+ r
for j in eachindex(k_discrete)
__maybe_matmul!(new_stages.u[j], k_discrete[j].du[:, 1:stage], x_star[idx₁])
__maybe_matmul!(new_stages.u[j], [dc.du for dc in k_discrete[j]], x_star[idx₁])
end
if r > 1
for j in eachindex(k_interp)
Expand Down Expand Up @@ -620,7 +620,7 @@ end
idx₁ = ((1:stage) .- 1) .* (s_star - stage) .+ r
idx₂ = ((1:(r - 1)) .+ stage .- 1) .* (s_star - stage) .+ r
for j in eachindex(k_discrete)
__maybe_matmul!(new_stages.u[j], k_discrete[j][:, 1:stage], x_star[idx₁])
__maybe_matmul!(new_stages.u[j], k_discrete[j], x_star[idx₁])
end
if r > 1
for j in eachindex(k_interp)
Expand Down Expand Up @@ -675,7 +675,7 @@ end
(; s_star) = cache.ITU

fᵢ₂_cache .= zero(z)
__maybe_matmul!(fᵢ₂_cache, k_discrete[i].du[:, 1:stage], w[1:stage])
__maybe_matmul!(fᵢ₂_cache, [dc.du for dc in k_discrete[i]], w[1:stage])
__maybe_matmul!(
fᵢ₂_cache, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true
)
Expand All @@ -691,7 +691,7 @@ end
(; s_star) = cache.ITU

fᵢ₂_cache .= zero(z)
__maybe_matmul!(fᵢ₂_cache, k_discrete[i][:, 1:stage], w[1:stage])
__maybe_matmul!(fᵢ₂_cache, k_discrete[i], w[1:stage])
__maybe_matmul!(
fᵢ₂_cache, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true
)
Expand All @@ -708,13 +708,14 @@ end
(; stage, k_discrete, k_interp) = cache
(; s_star) = cache.ITU

k_du = [dc.du for dc in k_discrete[i]]
z .= zero(z)
__maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage])
__maybe_matmul!(z, k_du, w[1:stage])
__maybe_matmul!(
z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true
)
z′ .= zero(z′)
__maybe_matmul!(z′, k_discrete[i].du[:, 1:stage], w′[1:stage])
__maybe_matmul!(z′, k_du, w′[1:stage])
__maybe_matmul!(
z′, k_interp.u[i][:, 1:(s_star - stage)], w′[(stage + 1):s_star], true, true
)
Expand All @@ -731,12 +732,12 @@ end
(; s_star) = cache.ITU

z .= zero(z)
__maybe_matmul!(z, k_discrete[i][:, 1:stage], w[1:stage])
__maybe_matmul!(z, k_discrete[i], w[1:stage])
__maybe_matmul!(
z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true
)
z′ .= zero(z′)
__maybe_matmul!(z′, k_discrete[i][:, 1:stage], w′[1:stage])
__maybe_matmul!(z′, k_discrete[i], w′[1:stage])
__maybe_matmul!(
z′, k_interp.u[i][:, 1:(s_star - stage)], w′[(stage + 1):s_star], true, true
)
Expand Down
Loading
Loading