diff --git a/lib/BoundaryValueDiffEqCore/src/types.jl b/lib/BoundaryValueDiffEqCore/src/types.jl index 1e35a654f..8b336097e 100644 --- a/lib/BoundaryValueDiffEqCore/src/types.jl +++ b/lib/BoundaryValueDiffEqCore/src/types.jl @@ -175,14 +175,11 @@ function __maybe_allocate_diffcache(x, chunksize, jac_alg) end __maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(zero(x.du), chunksize) -## get_tmp shows a warning as it should on cache expansion, this behavior however is -## expected for adaptive BVP solvers so we write our own `get_tmp` and drop the warning logs - -@inline function get_tmp(dc, u) - return Logging.with_logger(Logging.NullLogger()) do - PreallocationTools.get_tmp(dc, u) - end -end +## PreallocationTools.get_tmp may warn on cache expansion (resize), which is expected +## behavior for adaptive BVP solvers. We call it directly here for performance; +## warnings during adaptive cache expansion are suppressed at the __expand_cache! call site. +@inline get_tmp(dc::DiffCache, u) = PreallocationTools.get_tmp(dc, u) +@inline get_tmp(dc, u) = dc # DiffCache struct DiffCacheNeeded end diff --git a/lib/BoundaryValueDiffEqCore/src/utils.jl b/lib/BoundaryValueDiffEqCore/src/utils.jl index c2d6b15aa..9814dcf22 100644 --- a/lib/BoundaryValueDiffEqCore/src/utils.jl +++ b/lib/BoundaryValueDiffEqCore/src/utils.jl @@ -1,3 +1,6 @@ +@inline _maybe_get_tmp(x::DiffCache, u) = PreallocationTools.get_tmp(x, u) +@inline _maybe_get_tmp(x, u) = x + recursive_length(x::Vector{<:AbstractArray}) = sum(length, x) recursive_length(x::Vector{<:DiffCache}) = sum(xᵢ -> length(xᵢ.u), x) @@ -15,6 +18,16 @@ end end return y end + +@views function recursive_flatten!(y::AbstractVector, x::AbstractVector{<:DiffCache}, u::AbstractVector) + i = 0 + for xᵢ in x + tmp = PreallocationTools.get_tmp(xᵢ, u) + copyto!(y[(i + 1):(i + length(tmp))], tmp) + i += length(tmp) + end + return y +end @views function recursive_flatten_twopoint!(y::AbstractVector, x::Vector{<:AbstractArray}, sizes) x_, xiter = first(x), x[2:end] copyto!(y[1:prod(sizes[1])], x_[1:prod(sizes[1])]) @@ -27,6 +40,21 @@ end return y end +@views function recursive_flatten_twopoint!( + y::AbstractVector, x::AbstractVector{<:DiffCache}, u::AbstractVector, sizes + ) + x_ = PreallocationTools.get_tmp(first(x), u) + copyto!(y[1:prod(sizes[1])], x_[1:prod(sizes[1])]) + i = prod(sizes[1]) + for j in 2:length(x) + xᵢ = PreallocationTools.get_tmp(x[j], u) + copyto!(y[(i + 1):(i + length(xᵢ))], xᵢ) + i += length(xᵢ) + end + copyto!(y[(i + 1):(i + prod(sizes[2]))], x_[(end - prod(sizes[2]) + 1):end]) + return y +end + @views function recursive_unflatten!(y::Vector{<:AbstractArray}, x::AbstractVector) i = 0 for yᵢ in y @@ -40,6 +68,28 @@ end return recursive_unflatten!(get_tmp.(y, (x,)), x) end +# Non-allocating version with pre-allocated output cache. +# When element types match (primal path), fills y_cache in-place. +# When they don't (Dual path), falls back to broadcast allocation. +@views function recursive_unflatten!( + y::Vector{<:DiffCache}, y_cache::Vector{<:AbstractVector{T}}, x::AbstractVector{T} + ) where {T} + i = 0 + for (j, yᵢ) in enumerate(y) + tmp = PreallocationTools.get_tmp(yᵢ, x) + y_cache[j] = tmp + copyto!(tmp, x[(i + 1):(i + length(tmp))]) + i += length(tmp) + end + return y_cache +end + +@views function recursive_unflatten!( + y::Vector{<:DiffCache}, y_cache::Vector, x::AbstractVector + ) + return recursive_unflatten!(get_tmp.(y, (x,)), x) +end + @views function recursive_unflatten!(y::AbstractVectorOfArray, x::AbstractVector) i = 0 for yᵢ in y @@ -71,6 +121,21 @@ end return z end +function __maybe_matmul!(z::AbstractArray, A::AbstractVector{<:AbstractVector}, b, + α = eltype(z)(1), β = eltype(z)(0)) + @inbounds for i in eachindex(z) + z[i] *= β + end + @inbounds for j in eachindex(b) + bj = α * b[j] + Aj = A[j] + @simd ivdep for i in eachindex(z) + z[i] += Aj[i] * bj + end + end + return z +end + """ interval(mesh, t) @@ -261,6 +326,21 @@ function __resize!(x::AbstractVector{<:DiffCache}, n, M) return x end +function __resize!(x::AbstractVector{<:AbstractVector{<:DiffCache}}, n, M) + N = n - length(x) + N == 0 && return x + if N > 0 + chunksize = pickchunksize(M * (N + length(x))) + append!(x, [ + [__maybe_allocate_diffcache(dc, chunksize) for dc in last(x)] + for _ in 1:N + ]) + else + resize!(x, n) + end + return x +end + function __resize!(x::AbstractVectorOfArray, n, M) N = n - length(x) N == 0 && return x diff --git a/lib/BoundaryValueDiffEqFIRK/src/firk.jl b/lib/BoundaryValueDiffEqFIRK/src/firk.jl index adb324f74..dda70698a 100644 --- a/lib/BoundaryValueDiffEqFIRK/src/firk.jl +++ b/lib/BoundaryValueDiffEqFIRK/src/firk.jl @@ -1337,7 +1337,7 @@ end y_ = recursive_unflatten!(y, u) resids = [get_tmp(r, u) for r in residual] Φ!(resids[2:end], cache, y_, u, trait, constraint) - eval_sol.u[1:end] .= y_ + copyto!(eval_sol.u, y_) eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh) recursive_flatten!(resid, resids) return nothing @@ -1350,7 +1350,7 @@ end y_ = recursive_unflatten!(y, u) resids = [r for r in residual] Φ!(resids[2:end], cache, y_, u, trait, constraint) - eval_sol.u[1:end] .= y_ + copyto!(eval_sol.u, y_) eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh) recursive_flatten!(resid, resids) return nothing @@ -1410,7 +1410,7 @@ end u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, eval_sol, trait ) where {BC} y_ = recursive_unflatten!(y, u) - eval_sol.u[1:end] .= y_ + copyto!(eval_sol.u, y_) resid_bc = eval_bc_residual(pt, bc, eval_sol, p, mesh) resid_co = Φ(cache, y_, u, trait) return vcat(resid_bc, mapreduce(vec, vcat, resid_co)) diff --git a/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl index 927f48ec2..b6216c3e6 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl @@ -24,7 +24,8 @@ using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm, DiffCacheNeeded, NoDiffCacheNeeded, __split_kwargs, __concrete_kwargs, __FastShortcutNonlinearPolyalg, __construct_internal_problem, __internal_solve, - __default_sparsity_detector, __build_cost, __add_singular_term! + __default_sparsity_detector, __build_cost, __add_singular_term!, + _maybe_get_tmp using ConcreteStructs: @concrete using DifferentiationInterface: DifferentiationInterface, Constant, prepare_jacobian diff --git a/lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl b/lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl index 2a8e6b8e7..38252b53c 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl @@ -583,7 +583,7 @@ Here, the ki_interp is the stages in one subinterval. idx₁ = ((1:stage) .- 1) .* (s_star - stage) .+ r idx₂ = ((1:(r - 1)) .+ stage .- 1) .* (s_star - stage) .+ r for j in eachindex(k_discrete) - __maybe_matmul!(new_stages.u[j], k_discrete[j].du[:, 1:stage], x_star[idx₁]) + __maybe_matmul!(new_stages.u[j], [dc.du for dc in k_discrete[j]], x_star[idx₁]) end if r > 1 for j in eachindex(k_interp) @@ -620,7 +620,7 @@ end idx₁ = ((1:stage) .- 1) .* (s_star - stage) .+ r idx₂ = ((1:(r - 1)) .+ stage .- 1) .* (s_star - stage) .+ r for j in eachindex(k_discrete) - __maybe_matmul!(new_stages.u[j], k_discrete[j][:, 1:stage], x_star[idx₁]) + __maybe_matmul!(new_stages.u[j], k_discrete[j], x_star[idx₁]) end if r > 1 for j in eachindex(k_interp) @@ -675,7 +675,7 @@ end (; s_star) = cache.ITU fᵢ₂_cache .= zero(z) - __maybe_matmul!(fᵢ₂_cache, k_discrete[i].du[:, 1:stage], w[1:stage]) + __maybe_matmul!(fᵢ₂_cache, [dc.du for dc in k_discrete[i]], w[1:stage]) __maybe_matmul!( fᵢ₂_cache, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true ) @@ -691,7 +691,7 @@ end (; s_star) = cache.ITU fᵢ₂_cache .= zero(z) - __maybe_matmul!(fᵢ₂_cache, k_discrete[i][:, 1:stage], w[1:stage]) + __maybe_matmul!(fᵢ₂_cache, k_discrete[i], w[1:stage]) __maybe_matmul!( fᵢ₂_cache, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true ) @@ -708,13 +708,14 @@ end (; stage, k_discrete, k_interp) = cache (; s_star) = cache.ITU + k_du = [dc.du for dc in k_discrete[i]] z .= zero(z) - __maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage]) + __maybe_matmul!(z, k_du, w[1:stage]) __maybe_matmul!( z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true ) z′ .= zero(z′) - __maybe_matmul!(z′, k_discrete[i].du[:, 1:stage], w′[1:stage]) + __maybe_matmul!(z′, k_du, w′[1:stage]) __maybe_matmul!( z′, k_interp.u[i][:, 1:(s_star - stage)], w′[(stage + 1):s_star], true, true ) @@ -731,12 +732,12 @@ end (; s_star) = cache.ITU z .= zero(z) - __maybe_matmul!(z, k_discrete[i][:, 1:stage], w[1:stage]) + __maybe_matmul!(z, k_discrete[i], w[1:stage]) __maybe_matmul!( z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true ) z′ .= zero(z′) - __maybe_matmul!(z′, k_discrete[i][:, 1:stage], w′[1:stage]) + __maybe_matmul!(z′, k_discrete[i], w′[1:stage]) __maybe_matmul!( z′, k_interp.u[i][:, 1:(s_star - stage)], w′[(stage + 1):s_star], true, true ) diff --git a/lib/BoundaryValueDiffEqMIRK/src/collocation.jl b/lib/BoundaryValueDiffEqMIRK/src/collocation.jl index 4a7216d4b..103af3836 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/collocation.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/collocation.jl @@ -5,53 +5,66 @@ function Φ!(residual, cache::MIRKCache, y, u, trait, constraint) ) end -@views function Φ!( +function Φ!( residual, fᵢ_cache, k_discrete, f!, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int, f_prototype, singular_term, ::DiffCacheNeeded, ::Val{true} ) (; c, v, x, b) = TU L_f_prototype = length(f_prototype) - tmpy, - tmpu = get_tmp(fᵢ_cache, u)[1:L_f_prototype], - get_tmp(fᵢ_cache, u)[(L_f_prototype + 1):end] + tmpy = @view get_tmp(fᵢ_cache, u)[1:L_f_prototype] + tmpu = @view get_tmp(fᵢ_cache, u)[(L_f_prototype + 1):end] T = eltype(u) for i in eachindex(k_discrete) - K = get_tmp(k_discrete[i], u) - residᵢ = residual[i] + residᵢ = _maybe_get_tmp(residual[i], u) h = mesh_dt[i] yᵢ = get_tmp(y[i], u) yᵢ₊₁ = get_tmp(y[i + 1], u) - yᵢ, uᵢ = yᵢ[1:L_f_prototype], yᵢ[(L_f_prototype + 1):end] - yᵢ₊₁, uᵢ₊₁ = yᵢ₊₁[1:L_f_prototype], yᵢ₊₁[(L_f_prototype + 1):end] + yᵢ_f = @view yᵢ[1:L_f_prototype] + uᵢ = @view yᵢ[(L_f_prototype + 1):end] + yᵢ₊₁_f = @view yᵢ₊₁[1:L_f_prototype] + uᵢ₊₁ = @view yᵢ₊₁[(L_f_prototype + 1):end] for r in 1:stage - @. tmpy = (1 - v[r]) * yᵢ + v[r] * yᵢ₊₁ + @. tmpy = (1 - v[r]) * yᵢ_f + v[r] * yᵢ₊₁_f @. tmpu = (1 - v[r]) * uᵢ + v[r] * uᵢ₊₁ - __maybe_matmul!(tmpy, K[:, 1:(r - 1)], x[r, 1:(r - 1)], h, T(1)) - f!(K[:, r], vcat(tmpy, tmpu), p, mesh[i] + c[r] * h) + @inbounds for j in 1:(r - 1) + Kⱼ = get_tmp(k_discrete[i][j], u) + xrj = h * x[r, j] + @simd ivdep for k in 1:L_f_prototype + tmpy[k] += Kⱼ[k] * xrj + end + end + K_r = get_tmp(k_discrete[i][r], u) + f!(K_r, vcat(tmpy, tmpu), p, mesh[i] + c[r] * h) end # Update residual - @. residᵢ = yᵢ₊₁ - yᵢ - __maybe_matmul!(residᵢ, K[:, 1:stage], b[1:stage], -h, T(1)) + @. residᵢ = yᵢ₊₁_f - yᵢ_f + @inbounds for j in 1:stage + Kⱼ = get_tmp(k_discrete[i][j], u) + mbhj = -h * b[j] + @simd ivdep for k in 1:L_f_prototype + residᵢ[k] += Kⱼ[k] * mbhj + end + end end end -@views function Φ!( +function Φ!( residual, fᵢ_cache, k_discrete, f!, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int, _, singular_term, ::DiffCacheNeeded, constraint::Val{false} ) (; c, v, x, b) = TU tmp = get_tmp(fᵢ_cache, u) + N = length(tmp) T = eltype(u) for i in eachindex(k_discrete) - K = get_tmp(k_discrete[i], u) - residᵢ = residual[i] + residᵢ = _maybe_get_tmp(residual[i], u) h = mesh_dt[i] yᵢ = get_tmp(y[i], u) @@ -59,28 +72,41 @@ end for r in 1:stage @. tmp = (1 - v[r]) * yᵢ + v[r] * yᵢ₊₁ - __maybe_matmul!(tmp, K[:, 1:(r - 1)], x[r, 1:(r - 1)], h, T(1)) + @inbounds for j in 1:(r - 1) + Kⱼ = get_tmp(k_discrete[i][j], u) + xrj = h * x[r, j] + @simd ivdep for k in 1:N + tmp[k] += Kⱼ[k] * xrj + end + end t = mesh[i] + c[r] * h - f!(K[:, r], tmp, p, t) - __add_singular_term!(K[:, r], singular_term, tmp, t) + K_r = get_tmp(k_discrete[i][r], u) + f!(K_r, tmp, p, t) + __add_singular_term!(K_r, singular_term, tmp, t) end # Update residual @. residᵢ = yᵢ₊₁ - yᵢ - __maybe_matmul!(residᵢ, K[:, 1:stage], b[1:stage], -h, T(1)) + @inbounds for j in 1:stage + Kⱼ = get_tmp(k_discrete[i][j], u) + mbhj = -h * b[j] + @simd ivdep for k in 1:N + residᵢ[k] += Kⱼ[k] * mbhj + end + end end end -@views function Φ!( +function Φ!( residual, fᵢ_cache, k_discrete, f!, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int, _, singular_term, ::NoDiffCacheNeeded, ::Val{false} ) (; c, v, x, b) = TU - tmp = similar(fᵢ_cache) + tmp = fᵢ_cache + N = length(tmp) T = eltype(u) for i in eachindex(k_discrete) - K = k_discrete[i] residᵢ = residual[i] h = mesh_dt[i] @@ -89,15 +115,28 @@ end for r in 1:stage @. tmp = (1 - v[r]) * yᵢ + v[r] * yᵢ₊₁ - __maybe_matmul!(tmp, K[:, 1:(r - 1)], x[r, 1:(r - 1)], h, T(1)) + @inbounds for j in 1:(r - 1) + Kⱼ = k_discrete[i][j] + xrj = h * x[r, j] + @simd ivdep for k in 1:N + tmp[k] += Kⱼ[k] * xrj + end + end t = mesh[i] + c[r] * h - f!(K[:, r], tmp, p, t) - __add_singular_term!(K[:, r], singular_term, tmp, t) + K_r = k_discrete[i][r] + f!(K_r, tmp, p, t) + __add_singular_term!(K_r, singular_term, tmp, t) end # Update residual @. residᵢ = yᵢ₊₁ - yᵢ - __maybe_matmul!(residᵢ, K[:, 1:stage], b[1:stage], -h, T(1)) + @inbounds for j in 1:stage + Kⱼ = k_discrete[i][j] + mbhj = -h * b[j] + @simd ivdep for k in 1:N + residᵢ[k] += Kⱼ[k] * mbhj + end + end end end @@ -108,16 +147,16 @@ function Φ(cache::MIRKCache, y, u, trait) ) end -@views function Φ( +function Φ( fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int, singular_term, ::DiffCacheNeeded ) (; c, v, x, b) = TU residuals = [safe_similar(yᵢ) for yᵢ in y[1:(end - 1)]] tmp = get_tmp(fᵢ_cache, u) + N = length(tmp) T = eltype(u) for i in eachindex(k_discrete) - K = get_tmp(k_discrete[i], u) residᵢ = residuals[i] h = mesh_dt[i] @@ -126,30 +165,43 @@ end for r in 1:stage @. tmp = (1 - v[r]) * yᵢ + v[r] * yᵢ₊₁ - __maybe_matmul!(tmp, K[:, 1:(r - 1)], x[r, 1:(r - 1)], h, T(1)) + @inbounds for j in 1:(r - 1) + Kⱼ = get_tmp(k_discrete[i][j], u) + xrj = h * x[r, j] + @simd ivdep for k in 1:N + tmp[k] += Kⱼ[k] * xrj + end + end t = mesh[i] + c[r] * h - K[:, r] .= f(tmp, p, t) - __add_singular_term!(K[:, r], singular_term, tmp, t) + K_r = get_tmp(k_discrete[i][r], u) + K_r .= f(tmp, p, t) + __add_singular_term!(K_r, singular_term, tmp, t) end # Update residual @. residᵢ = yᵢ₊₁ - yᵢ - __maybe_matmul!(residᵢ, K[:, 1:stage], b[1:stage], -h, T(1)) + @inbounds for j in 1:stage + Kⱼ = get_tmp(k_discrete[i][j], u) + mbhj = -h * b[j] + @simd ivdep for k in 1:N + residᵢ[k] += Kⱼ[k] * mbhj + end + end end return residuals end -@views function Φ( +function Φ( fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int, singular_term, ::NoDiffCacheNeeded ) (; c, v, x, b) = TU residuals = [safe_similar(yᵢ) for yᵢ in y[1:(end - 1)]] tmp = similar(fᵢ_cache) + N = length(tmp) T = eltype(u) for i in eachindex(k_discrete) - K = k_discrete[i] residᵢ = residuals[i] h = mesh_dt[i] @@ -158,15 +210,27 @@ end for r in 1:stage @. tmp = (1 - v[r]) * yᵢ + v[r] * yᵢ₊₁ - __maybe_matmul!(tmp, K[:, 1:(r - 1)], x[r, 1:(r - 1)], h, T(1)) + @inbounds for j in 1:(r - 1) + Kⱼ = k_discrete[i][j] + xrj = h * x[r, j] + @simd ivdep for k in 1:N + tmp[k] += Kⱼ[k] * xrj + end + end t = mesh[i] + c[r] * h - K[:, r] .= f(tmp, p, t) - __add_singular_term!(K[:, r], singular_term, tmp, t) + k_discrete[i][r] .= f(tmp, p, t) + __add_singular_term!(k_discrete[i][r], singular_term, tmp, t) end # Update residual @. residᵢ = yᵢ₊₁ - yᵢ - __maybe_matmul!(residᵢ, K[:, 1:stage], b[1:stage], -h, T(1)) + @inbounds for j in 1:stage + Kⱼ = k_discrete[i][j] + mbhj = -h * b[j] + @simd ivdep for k in 1:N + residᵢ[k] += Kⱼ[k] * mbhj + end + end end return residuals diff --git a/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl b/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl index 9c74bb8fb..1732427d2 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl @@ -104,7 +104,13 @@ end # state variables have their interpolation polynomials length_z = has_control ? length(cache.prob.f.f_prototype) : length(z) z .= zero(z) - __maybe_matmul!(z[1:length_z], k_discrete[i].du[1:length_z, 1:stage], w[1:stage]) + k_du = [dc.du for dc in k_discrete[i]] + if has_control + k_du_sub = [v[1:length_z] for v in k_du] + __maybe_matmul!(z[1:length_z], k_du_sub, w[1:stage]) + else + __maybe_matmul!(z[1:length_z], k_du, w[1:stage]) + end __maybe_matmul!( z[1:length_z], k_interp.u[i][1:length_z, 1:(s_star - stage)], w[(stage + 1):s_star], true, true @@ -132,7 +138,12 @@ end length_z = has_control ? length(cache.prob.f.f_prototype) : length(z) z .= zero(z) - __maybe_matmul!(z[1:length_z], k_discrete[i][1:length_z, 1:stage], w[1:stage]) + if has_control + k_sub = [v[1:length_z] for v in k_discrete[i]] + __maybe_matmul!(z[1:length_z], k_sub, w[1:stage]) + else + __maybe_matmul!(z[1:length_z], k_discrete[i], w[1:stage]) + end __maybe_matmul!( z[1:length_z], k_interp.u[i][1:length_z, 1:(s_star - stage)], w[(stage + 1):s_star], true, true @@ -159,7 +170,13 @@ end length_z = has_control ? length(cache.prob.f.f_prototype) : length(z′) z′ .= zero(z′) - __maybe_matmul!(z′[1:length_z], k_discrete[i].du[1:length_z, 1:stage], w′[1:stage]) + k_du = [dc.du for dc in k_discrete[i]] + if has_control + k_du_sub = [v[1:length_z] for v in k_du] + __maybe_matmul!(z′[1:length_z], k_du_sub, w′[1:stage]) + else + __maybe_matmul!(z′[1:length_z], k_du, w′[1:stage]) + end __maybe_matmul!( z′[1:length_z], k_interp.u[i][1:length_z, 1:(s_star - stage)], w′[(stage + 1):s_star], true, true @@ -183,7 +200,12 @@ end length_z = has_control ? length(cache.prob.f.f_prototype) : length(z′) z′ .= zero(z′) - __maybe_matmul!(z′[1:length_z], k_discrete[i][1:length_z, 1:stage], w′[1:stage]) + if has_control + k_sub = [v[1:length_z] for v in k_discrete[i]] + __maybe_matmul!(z′[1:length_z], k_sub, w′[1:stage]) + else + __maybe_matmul!(z′[1:length_z], k_discrete[i], w′[1:stage]) + end __maybe_matmul!( z′[1:length_z], k_interp.u[i][1:length_z, 1:(s_star - stage)], w′[(stage + 1):s_star], true, true @@ -215,8 +237,8 @@ function (s::EvalSol{C})(tval::Number) where {C <: MIRKCache} dt = cache.mesh_dt[ii] τ = (tval - t[ii]) / dt w, _ = evalsol_interp_weights(τ, alg) - K = __needs_diffcache(alg.jac_alg) ? @view(k_discrete[ii].du[:, 1:stage]) : - @view(k_discrete[ii][:, 1:stage]) + K = __needs_diffcache(alg.jac_alg) ? [dc.du for dc in k_discrete[ii]] : + k_discrete[ii] __maybe_matmul!(z[1:length_z], K, @view(w[1:stage])) # control variable just use linear interpolation @@ -245,8 +267,8 @@ function (s::EvalSol{C})(tvals::AbstractArray{<:Number}) where {C <: MIRKCache} dt = mesh_dt[ii] τ = (tval - t[ii]) / dt w, _ = evalsol_interp_weights(τ, alg) - K = __needs_diffcache(alg.jac_alg) ? @view(k_discrete[ii].du[:, 1:stage]) : - @view(k_discrete[ii][:, 1:stage]) + K = __needs_diffcache(alg.jac_alg) ? [dc.du for dc in k_discrete[ii]] : + k_discrete[ii] __maybe_matmul!(zvals[i][1:length_z], K, @view(w[1:stage])) # control variable just use linear interpolation @@ -268,7 +290,7 @@ function (s::EvalSol{C})(tval::Number, ::Type{Val{1}}) where {C <: MIRKCache} dt = mesh_dt[ii] τ = (tval - t[ii]) / dt _, w′ = interp_weights(τ, alg) - __maybe_matmul!(z′, @view(k_discrete[ii].du[:, 1:stage]), @view(w′[1:stage])) + __maybe_matmul!(z′, [dc.du for dc in k_discrete[ii]], @view(w′[1:stage])) return z′ end diff --git a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl index ca49344d9..e92ffb747 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl @@ -22,9 +22,10 @@ y y₀ residual - # The following 2 caches are never resized + # The following 3 caches are never resized fᵢ_cache fᵢ₂_cache + y_cache # Pre-allocated get_tmp pointers for primal path errors new_stages resid_size @@ -66,6 +67,7 @@ function SciMLBase.__init( fᵢ_cache = __alloc(zero(X)) fᵢ₂_cache = vec(zero(X)) + y_cache = Vector{Vector{T}}(undef, Nig + 1) # Don't flatten this here, since we need to expand it later if needed y₀ = __initial_guess_on_mesh(X, mesh, prob.p) @@ -78,13 +80,15 @@ function SciMLBase.__init( k_discrete = if !constraint [ - __maybe_allocate_diffcache(safe_similar(X, N, stage), chunksize, alg.jac_alg) - for _ in 1:Nig + [__maybe_allocate_diffcache(safe_similar(X, N), chunksize, alg.jac_alg) + for _ in 1:stage] + for _ in 1:Nig ] else [ - __maybe_allocate_diffcache(safe_similar(X, L_f_prototype, stage), chunksize, alg.jac_alg) - for _ in 1:Nig + [__maybe_allocate_diffcache(safe_similar(X, L_f_prototype), chunksize, alg.jac_alg) + for _ in 1:stage] + for _ in 1:Nig ] end k_interp = if !constraint @@ -231,7 +235,7 @@ function SciMLBase.__init( return MIRKCache{iip, T, use_both, typeof(diffcache), tune_parameters}( alg_order(alg), stage, N, size(X), f, bc, prob_, prob.problem_type, prob.p, alg, TU, ITU, f_prototype, bcresid_prototype, mesh, mesh_dt, k_discrete, k_interp, y, - y₀, residual, fᵢ_cache, fᵢ₂_cache, errors, new_stages, resid₁_size, prob.singular_term + y₀, residual, fᵢ_cache, fᵢ₂_cache, y_cache, errors, new_stages, resid₁_size, prob.singular_term , nlsolve_kwargs, optimize_kwargs, (; abstol, dt, adaptive, controller, tune_parameters, kwargs...), verbose_spec ) end @@ -251,6 +255,7 @@ function __expand_cache!(cache::MIRKCache{iip, T, use_both}) where {iip, T, use_ __resize!(cache.residual, Nₙ, cache.M) __resize!(cache.errors, ifelse(use_both, 2 * (Nₙ - 1), (Nₙ - 1)), cache.M) __resize!(cache.new_stages, Nₙ - 1, cache.M) + resize!(cache.y_cache, Nₙ) return cache end @@ -439,13 +444,12 @@ end resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, mesh, cache, EvalSol, trait::DiffCacheNeeded, constraint ) where {BC} - y_ = recursive_unflatten!(y, u) - resids = [get_tmp(r, u) for r in residual] - Φ!(resids[2:end], cache, y_, u, trait, constraint) - EvalSol.u[1:end] .= __restructure_sol(y_, cache.in_size) - EvalSol.cache.k_discrete[1:end] .= cache.k_discrete - eval_bc_residual!(resids[1], pt, bc!, EvalSol, p, mesh) - recursive_flatten!(resid, resids) + y_ = recursive_unflatten!(y, cache.y_cache, u) + Φ!(residual[2:end], cache, y, u, trait, constraint) + copyto!(EvalSol.u, __restructure_sol(y_, cache.in_size)) + copyto!(EvalSol.cache.k_discrete, cache.k_discrete) + eval_bc_residual!(get_tmp(residual[1], u), pt, bc!, EvalSol, p, mesh) + recursive_flatten!(resid, residual, u) return nothing end @@ -455,8 +459,8 @@ end ) where {BC} y_ = recursive_unflatten!(y, u) Φ!(residual[2:end], cache, y_, u, trait, constraint) - EvalSol.u[1:end] .= __restructure_sol(y_, cache.in_size) - EvalSol.cache.k_discrete[1:end] .= cache.k_discrete + copyto!(EvalSol.u, __restructure_sol(y_, cache.in_size)) + copyto!(EvalSol.cache.k_discrete, cache.k_discrete) eval_bc_residual!(residual[1], pt, bc!, EvalSol, p, mesh) recursive_flatten!(resid, residual) return nothing @@ -479,13 +483,13 @@ end resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, residual, mesh, cache, _, trait::DiffCacheNeeded, constraint ) where {BC1, BC2} - y_ = recursive_unflatten!(y, u) - resids = [get_tmp(r, u) for r in residual] - Φ!(resids[2:end], cache, y_, u, trait, constraint) - resida = resids[1][1:prod(cache.resid_size[1])] - residb = resids[1][(prod(cache.resid_size[1]) + 1):end] + y_ = recursive_unflatten!(y, cache.y_cache, u) + Φ!(residual[2:end], cache, y, u, trait, constraint) + resid0 = get_tmp(residual[1], u) + resida = resid0[1:prod(cache.resid_size[1])] + residb = resid0[(prod(cache.resid_size[1]) + 1):end] eval_bc_residual!((resida, residb), pt, bc!, y_, p, mesh) - recursive_flatten_twopoint!(resid, resids, cache.resid_size) + recursive_flatten_twopoint!(resid, residual, u, cache.resid_size) return nothing end @@ -516,8 +520,8 @@ end ) where {BC} y_ = recursive_unflatten!(y, u) resid_co = Φ(cache, y_, u, trait) - EvalSol.u[1:end] .= __restructure_sol(y_, cache.in_size) - EvalSol.cache.k_discrete[1:end] .= cache.k_discrete + copyto!(EvalSol.u, __restructure_sol(y_, cache.in_size)) + copyto!(EvalSol.cache.k_discrete, cache.k_discrete) resid_bc = eval_bc_residual(pt, bc, EvalSol, p, mesh) return vcat(resid_bc, mapreduce(vec, vcat, resid_co)) end @@ -552,10 +556,10 @@ end @views function __mirk_loss_collocation!( resid, u, p, y, mesh, residual, cache, trait::DiffCacheNeeded, constraint ) - y_ = recursive_unflatten!(y, u) - resids = [get_tmp(r, u) for r in residual[2:end]] - Φ!(resids, cache, y_, u, trait, constraint) - recursive_flatten!(resid, resids) + recursive_unflatten!(y, cache.y_cache, u) + collocation_residual = residual[2:end] + Φ!(collocation_residual, cache, y, u, trait, constraint) + recursive_flatten!(resid, collocation_residual, u) return nothing end @@ -563,9 +567,9 @@ end resid, u, p, y, mesh, residual, cache, trait::NoDiffCacheNeeded, constraint ) y_ = recursive_unflatten!(y, u) - resids = [r for r in residual[2:end]] - Φ!(resids, cache, y_, u, trait, constraint) - recursive_flatten!(resid, resids) + collocation_residual = residual[2:end] + Φ!(collocation_residual, cache, y_, u, trait, constraint) + recursive_flatten!(resid, collocation_residual) return nothing end diff --git a/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl b/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl index 88f50ecb8..20925ec94 100644 --- a/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl +++ b/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl @@ -385,10 +385,10 @@ end y_ = recursive_unflatten!(y, u) resids = [get_tmp(r, u) for r in residual] Φ!(resids[3:end], cache, y_, u, p) - EvalSol.u[1:end] .= __restructure_sol(y_[1:length(cache.mesh)], cache.in_size) - EvalSol.cache.k_discrete[1:end] .= cache.k_discrete - EvalDSol.u[1:end] .= __restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size) - EvalDSol.cache.k_discrete[1:end] .= cache.k_discrete + copyto!(EvalSol.u, __restructure_sol(y_[1:length(cache.mesh)], cache.in_size)) + copyto!(EvalSol.cache.k_discrete, cache.k_discrete) + copyto!(EvalDSol.u, __restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size)) + copyto!(EvalDSol.cache.k_discrete, cache.k_discrete) eval_bc_residual!(resids[1:2], pt, bc, EvalSol, EvalDSol, p, mesh) recursive_flatten!(resid, resids) return nothing @@ -400,10 +400,10 @@ end ) where {BC} y_ = recursive_unflatten!(y, u) resid_co = Φ(cache, y_, u, p) - EvalSol.u[1:end] .= __restructure_sol(y_[1:length(cache.mesh)], cache.in_size) - EvalSol.cache.k_discrete[1:end] .= cache.k_discrete - EvalDSol.u[1:end] .= __restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size) - EvalDSol.cache.k_discrete[1:end] .= cache.k_discrete + copyto!(EvalSol.u, __restructure_sol(y_[1:length(cache.mesh)], cache.in_size)) + copyto!(EvalSol.cache.k_discrete, cache.k_discrete) + copyto!(EvalDSol.u, __restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size)) + copyto!(EvalDSol.cache.k_discrete, cache.k_discrete) resid_bc = eval_bc_residual(pt, bc, EvalSol, EvalDSol, p, mesh) return vcat(resid_bc, mapreduce(vec, vcat, resid_co)) end diff --git a/test/misc/allocation_tests.jl b/test/misc/allocation_tests.jl new file mode 100644 index 000000000..d51104942 --- /dev/null +++ b/test/misc/allocation_tests.jl @@ -0,0 +1,77 @@ +@testitem "MIRK Loss Function Allocations" tags=[:allocs] begin + using BoundaryValueDiffEq, BoundaryValueDiffEqMIRK, BoundaryValueDiffEqCore, LinearAlgebra + + function f!(du, u, p, t) + du[1] = u[2] + du[2] = -u[1] + return nothing + end + + function bc!(resid, sol, p, t) + resid[1] = sol(0.0)[1] - 1.0 + resid[2] = sol(1.0)[1] - cos(1.0) + return nothing + end + + function tpbc_a!(resid, ua, p) + resid[1] = ua[1] - 1.0 + return nothing + end + + function tpbc_b!(resid, ub, p) + resid[1] = ub[1] - cos(1.0) + return nothing + end + + u0 = [1.0, 0.0] + tspan = (0.0, 1.0) + + bvp = BVProblem(BVPFunction{true}(f!, bc!; bcresid_prototype = zeros(2)), u0, tspan) + tpbvp = BVProblem( + BVPFunction{true}(f!, (tpbc_a!, tpbc_b!); + bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true)), + u0, tspan) + + # Test that the loss function allocations scale sub-linearly with mesh size + # (i.e., per-step allocations are bounded, not proportional to mesh points) + for (name, prob) in [("StandardBVP", bvp), ("TwoPointBVP", tpbvp)] + for alg in [MIRK4(), MIRK5(), MIRK6()] + cache = SciMLBase.__init(prob, alg; dt = 0.1, adaptive = false) + nlprob = BoundaryValueDiffEqMIRK.__construct_problem( + cache, vec(cache.y₀), copy(cache.y₀)) + + u_test = copy(nlprob.u0) + resid_test = zeros(length(nlprob.u0)) + + # Warmup + nlprob.f(resid_test, u_test, nlprob.p) + + # Measure allocations per loss call + allocs = @allocated nlprob.f(resid_test, u_test, nlprob.p) + + # Loss function should allocate less than 10 KiB per call + # (the remaining allocations are from SubArray views in the inner loop + # which scale with mesh size but are small per-element) + @test allocs < 10 * 1024 # 10 KiB threshold + end + end + + # Test that non-adaptive solve allocations are bounded + for alg in [MIRK4(), MIRK5()] + # Small mesh + sol_small = solve(bvp, alg; dt = 0.1, adaptive = false) + @test sol_small.retcode == ReturnCode.Success + allocs_small = @allocated solve(bvp, alg; dt = 0.1, adaptive = false) + + # Larger mesh (5x) + sol_large = solve(bvp, alg; dt = 0.02, adaptive = false) + @test sol_large.retcode == ReturnCode.Success + allocs_large = @allocated solve(bvp, alg; dt = 0.02, adaptive = false) + + # Allocations should scale much less than 5x + # (ideally close to linear with mesh size due to Jacobian setup, + # but per-Newton-step allocations should be small) + ratio = allocs_large / allocs_small + @test ratio < 10 # Should be well under 10x for 5x more mesh points + end +end diff --git a/test/qa/runtests.jl b/test/qa/runtests.jl index b67625142..06f23b4b6 100644 --- a/test/qa/runtests.jl +++ b/test/qa/runtests.jl @@ -1,4 +1,5 @@ -using BoundaryValueDiffEq, Aqua, JET, Test, SciMLBase +using BoundaryValueDiffEq, BoundaryValueDiffEqMIRK, BoundaryValueDiffEqCore, + Aqua, JET, Test, SciMLBase @testset "Quality Assurance" begin @testset "Aqua" begin @@ -15,4 +16,40 @@ using BoundaryValueDiffEq, Aqua, JET, Test, SciMLBase ) @test length(JET.get_reports(rep)) == 0 end + + @testset "Zero per-step allocations in MIRK loss function" begin + function _f!(du, u, p, t) + du[1] = u[2] + du[2] = -u[1] + return nothing + end + function _bc!(resid, sol, p, t) + resid[1] = sol(0.0)[1] - 1.0 + resid[2] = sol(1.0)[1] - cos(1.0) + return nothing + end + u0 = [1.0, 0.0] + tspan = (0.0, 1.0) + bvp = BVProblem( + BVPFunction{true}(_f!, _bc!; bcresid_prototype = zeros(2)), u0, tspan) + + cache = SciMLBase.__init(bvp, MIRK4(); dt = 0.1, adaptive = false) + nlprob = BoundaryValueDiffEqMIRK.__construct_problem( + cache, vec(cache.y₀), copy(cache.y₀)) + + u_test = copy(nlprob.u0) + resid_test = zeros(length(nlprob.u0)) + p_test = nlprob.p + + # Verify loss function is allocation-free at runtime + function _bench_loss(f, resid, u, p, N) + for _ in 1:N + f(resid, u, p) + end + end + _bench_loss(nlprob.f, resid_test, u_test, p_test, 10) # warmup + stats = @timed _bench_loss(nlprob.f, resid_test, u_test, p_test, 10_000) + bytes_per_call = stats.bytes / 10_000 + @test bytes_per_call == 0.0 + end end