diff --git a/lib/BoundaryValueDiffEqCore/src/types.jl b/lib/BoundaryValueDiffEqCore/src/types.jl
index 542e3ada9..49bd55980 100644
--- a/lib/BoundaryValueDiffEqCore/src/types.jl
+++ b/lib/BoundaryValueDiffEqCore/src/types.jl
@@ -174,6 +174,7 @@ function __maybe_allocate_diffcache(x, chunksize, jac_alg)
     return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize; warn_on_resize = false) : x
 end
 
+
 function __maybe_allocate_diffcache(x::DiffCache, chunksize)
     return DiffCache(zero(x.du), chunksize; warn_on_resize = false)
 end
diff --git a/lib/BoundaryValueDiffEqCore/src/utils.jl b/lib/BoundaryValueDiffEqCore/src/utils.jl
index 4a6e5cc50..17eabbc39 100644
--- a/lib/BoundaryValueDiffEqCore/src/utils.jl
+++ b/lib/BoundaryValueDiffEqCore/src/utils.jl
@@ -1,3 +1,7 @@
+# `DiffCache`s are resolved with `get_tmp(x, u)`; any other `x` is returned unchanged.
+@inline _maybe_get_tmp(x::DiffCache, u) = PreallocationTools.get_tmp(x, u)
+@inline _maybe_get_tmp(x, u) = x
+
 recursive_length(x::Vector{<:AbstractArray}) = sum(length, x)
 recursive_length(x::Vector{<:DiffCache}) = sum(xᵢ -> length(xᵢ.u), x)
 
@@ -15,6 +19,18 @@ end
     end
     return y
 end
+
+# Copy `get_tmp(xᵢ, u)` for each `xᵢ` in `x` into consecutive ranges of `y`; returns `y`.
+# Each cache is resolved inside the loop, so no intermediate vector of buffers is built.
+@views function recursive_flatten!(y::AbstractVector, x::AbstractVector{<:DiffCache}, u::AbstractVector)
+    i = 0
+    for xᵢ in x
+        tmp = PreallocationTools.get_tmp(xᵢ, u)
+        copyto!(y[(i + 1):(i + length(tmp))], tmp)
+        i += length(tmp)
+    end
+    return y
+end
 @views function recursive_flatten_twopoint!(y::AbstractVector, x::Vector{<:AbstractArray}, sizes)
     x_, xiter = first(x), x[2:end]
     copyto!(y[1:prod(sizes[1])], x_[1:prod(sizes[1])])
@@ -27,6 +43,23 @@ end
     return y
 end
 
+# Two-point layout: first `prod(sizes[1])` entries of `get_tmp(x[1], u)`, then all of
+# `get_tmp(x[j], u)` for `j ≥ 2`, then the last `prod(sizes[2])` entries of the first.
+@views function recursive_flatten_twopoint!(
+        y::AbstractVector, x::AbstractVector{<:DiffCache}, u::AbstractVector, sizes
+    )
+    x_ = PreallocationTools.get_tmp(first(x), u)
+    copyto!(y[1:prod(sizes[1])], x_[1:prod(sizes[1])])
+    i = prod(sizes[1])
+    for j in 2:length(x)
+        xᵢ = PreallocationTools.get_tmp(x[j], u)
+        copyto!(y[(i + 1):(i + length(xᵢ))], xᵢ)
+        i += length(xᵢ)
+    end
+    copyto!(y[(i + 1):(i + prod(sizes[2]))], x_[(end - prod(sizes[2]) + 1):end])
+    return y
+end
+
 @views function recursive_unflatten!(y::Vector{<:AbstractArray}, x::AbstractVector)
     i = 0
     for yᵢ in y
@@ -40,6 +73,28 @@ end
     return recursive_unflatten!(get_tmp.(y, (x,)), x)
 end
 
+# Non-allocating version with pre-allocated output cache.
+# When element types match (primal path), fills y_cache in-place.
+# When they don't (Dual path), falls back to broadcast allocation.
+@views function recursive_unflatten!(
+        y::Vector{<:DiffCache}, y_cache::Vector{<:AbstractVector{T}}, x::AbstractVector{T}
+    ) where {T}
+    i = 0
+    for (j, yᵢ) in enumerate(y)
+        tmp = PreallocationTools.get_tmp(yᵢ, x)
+        y_cache[j] = tmp
+        copyto!(tmp, x[(i + 1):(i + length(tmp))])
+        i += length(tmp)
+    end
+    return y_cache
+end
+
+@views function recursive_unflatten!(
+        y::Vector{<:DiffCache}, y_cache::Vector, x::AbstractVector
+    )
+    return recursive_unflatten!(get_tmp.(y, (x,)), x)
+end
+
 @views function recursive_unflatten!(y::AbstractVectorOfArray, x::AbstractVector)
     i = 0
     for yᵢ in y
@@ -56,18 +111,23 @@ function diff!(dx, x)
     return dx
 end
 
-function __maybe_matmul!(z::AbstractArray, A, b, α = eltype(z)(1), β = eltype(z)(0))
-    return mul!(z, A, b, α, β)
-end
+# In-place `z = α * A * b + β * z` with plain loops; returns `z`.
+# `β == 0` overwrites `z` without reading it (the convention of the replaced `mul!` call),
+# so NaN/Inf left in a reused buffer by an earlier call cannot propagate.
+function __maybe_matmul!(z, A, b, α = one(eltype(z)), β = zero(eltype(z)))
+    # First z = β*z
+    @inbounds for i in eachindex(z)
+        z[i] = iszero(β) ? zero(eltype(z)) : z[i] * β
+    end
 
-# NOTE: We can implement it as mul! as above but then we pay the cost of moving
-# `w` to the GPU too many times. Instead if we iterate of w and w′ we save
-# that cost. Our main cost is anyways going to be due to a large `u0` and
-# we are going to use GPUs for that
-@views function __maybe_matmul!(z, A, b, α = eltype(z)(1), β = eltype(z)(0))
-    @simd ivdep for j in eachindex(b)
-        @inbounds @. z = α * A[:, j] * b[j] + β * z
+    # Then z += α*A*b
+    @inbounds for j in axes(A, 2)
+        bj = α * b[j]
+        for i in axes(A, 1)
+            z[i] += A[i, j] * bj
+        end
     end
+
     return z
 end
 
diff --git a/lib/BoundaryValueDiffEqFIRK/src/firk.jl b/lib/BoundaryValueDiffEqFIRK/src/firk.jl
index d98fc4fc4..af2b6b28c 100644
--- a/lib/BoundaryValueDiffEqFIRK/src/firk.jl
+++ b/lib/BoundaryValueDiffEqFIRK/src/firk.jl
@@ -1346,7 +1346,7 @@ end
     y_ = recursive_unflatten!(y, u)
     resids = [get_tmp(r, u) for r in residual]
     Φ!(resids[2:end], cache, y_, u, trait, constraint)
-    eval_sol.u[1:end] .= y_
+    copyto!(eval_sol.u, y_)
     eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh)
     recursive_flatten!(resid, resids)
     return nothing
@@ -1359,7 +1359,7 @@ end
     y_ = recursive_unflatten!(y, u)
     resids = [r for r in residual]
     Φ!(resids[2:end], cache, y_, u, trait, constraint)
-    eval_sol.u[1:end] .= y_
+    copyto!(eval_sol.u, y_)
     eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh)
     recursive_flatten!(resid, resids)
     return nothing
@@ -1419,7 +1419,7 @@ end
         u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, eval_sol, trait
     ) where {BC}
     y_ = recursive_unflatten!(y, u)
-    eval_sol.u[1:end] .= y_
+    copyto!(eval_sol.u, y_)
     resid_bc = eval_bc_residual(pt, bc, eval_sol, p, mesh)
     resid_co = Φ(cache, y_, u, trait)
     return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
diff --git a/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl
index 02f068360..bda2b9df9 100644
--- a/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl
+++ b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl
@@ -24,7 +24,8 @@ using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm,
     DiffCacheNeeded, NoDiffCacheNeeded, __split_kwargs,
     __concrete_kwargs, __FastShortcutNonlinearPolyalg,
     __construct_internal_problem, __internal_solve,
-    __default_sparsity_detector, __build_cost, __add_singular_term!
+    __default_sparsity_detector, __build_cost, __add_singular_term!,
+    _maybe_get_tmp
 
 using ConcreteStructs: @concrete
 using DifferentiationInterface: DifferentiationInterface, Constant, prepare_jacobian
diff --git a/lib/BoundaryValueDiffEqMIRK/src/collocation.jl b/lib/BoundaryValueDiffEqMIRK/src/collocation.jl
index 4a7216d4b..4351aad5b 100644
--- a/lib/BoundaryValueDiffEqMIRK/src/collocation.jl
+++ b/lib/BoundaryValueDiffEqMIRK/src/collocation.jl
@@ -19,7 +19,7 @@ end
     T = eltype(u)
     for i in eachindex(k_discrete)
         K = get_tmp(k_discrete[i], u)
-        residᵢ = residual[i]
+        residᵢ = _maybe_get_tmp(residual[i], u)
         h = mesh_dt[i]
         yᵢ = get_tmp(y[i], u)
 
@@ -51,7 +51,7 @@ end
     T = eltype(u)
     for i in eachindex(k_discrete)
         K = get_tmp(k_discrete[i], u)
-        residᵢ = residual[i]
+        residᵢ = _maybe_get_tmp(residual[i], u)
         h = mesh_dt[i]
         yᵢ = get_tmp(y[i], u)
 
@@ -77,7 +77,7 @@ end
     )
     (; c, v, x, b) = TU
 
-    tmp = similar(fᵢ_cache)
+    tmp = fᵢ_cache
     T = eltype(u)
     for i in eachindex(k_discrete)
         K = k_discrete[i]
diff --git a/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl b/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl
index 7e3747b90..f174d13f5 100644
--- a/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl
+++ b/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl
@@ -377,9 +377,9 @@ always update the intermediate solution with discrete solution + discrete stages
 (Continuous MIRK: u(meshᵢ + τ*dt) = yᵢ + dt sum br(τ)*kr).
 """
 @views function update_eval_sol!(eval_sol::EvalSol, y_, cache::MIRKCache)
-    eval_sol.u[1:end] .= __restructure_sol(y_, cache.in_size)
-    eval_sol.cache.k_discrete[1:end] .= cache.k_discrete
-    eval_sol.cache.k_interp.u[1:end] .= cache.k_interp.u
+    copyto!(eval_sol.u, __restructure_sol(y_, cache.in_size))
+    copyto!(eval_sol.cache.k_discrete, cache.k_discrete)
+    copyto!(eval_sol.cache.k_interp.u, cache.k_interp.u)
     interp_setup!(eval_sol.cache)
     return nothing
 end
diff --git a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl
index 8bcd2748e..2d0b698be 100644
--- a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl
+++ b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl
@@ -22,9 +22,10 @@
     y
     y₀
     residual
-    # The following 2 caches are never resized
+    # The following 3 caches are never resized
     fᵢ_cache
     fᵢ₂_cache
+    y_cache  # Pre-allocated get_tmp pointers for primal path
     errors
     new_stages
     resid_size
@@ -72,6 +73,7 @@ function SciMLBase.__init(
 
     fᵢ_cache = __alloc(zero(u0))
     fᵢ₂_cache = vec(zero(u0))
+    y_cache = Vector{Vector{T}}(undef, Nig + 1)
 
     # Don't flatten this here, since we need to expand it later if needed
     y₀ = __initial_guess_on_mesh(prob.u0, mesh, prob.p; tune_parameters = tune_parameters)
@@ -237,7 +239,7 @@ function SciMLBase.__init(
     return MIRKCache{iip, T, use_both, typeof(diffcache), tune_parameters}(
         alg_order(alg), stage, N, size(u0), f, bc, prob_, prob.problem_type, prob.p, alg, TU, ITU,
         f_prototype, bcresid_prototype, mesh, mesh_dt, k_discrete, k_interp, y,
-        y₀, residual, fᵢ_cache, fᵢ₂_cache, errors, new_stages, resid₁_size, prob.singular_term
+        y₀, residual, fᵢ_cache, fᵢ₂_cache, y_cache, errors, new_stages, resid₁_size, prob.singular_term
         , nlsolve_kwargs, optimize_kwargs, (; abstol, dt, adaptive, controller, tune_parameters, kwargs...), verbose_spec
     )
 end
@@ -257,6 +259,7 @@ function __expand_cache!(cache::MIRKCache{iip, T, use_both}) where {iip, T, use_
     __resize!(cache.residual, Nₙ, cache.M)
     __resize!(cache.errors, ifelse(use_both, 2 * (Nₙ - 1), (Nₙ - 1)), cache.M)
     __resize!(cache.new_stages, Nₙ - 1, cache.M)
+    resize!(cache.y_cache, Nₙ)
     return cache
 end
 
@@ -445,12 +448,11 @@ end
         resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual,
         mesh, cache, eval_sol, trait::DiffCacheNeeded, constraint
     ) where {BC}
-    y_ = recursive_unflatten!(y, u)
-    resids = [get_tmp(r, u) for r in residual]
-    Φ!(resids[2:end], cache, y_, u, trait, constraint)
+    y_ = recursive_unflatten!(y, cache.y_cache, u)
+    Φ!(residual[2:end], cache, y_, u, trait, constraint)
     update_eval_sol!(eval_sol, y_, cache)
-    eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh)
-    recursive_flatten!(resid, resids)
+    eval_bc_residual!(get_tmp(residual[1], u), pt, bc!, eval_sol, p, mesh)
+    recursive_flatten!(resid, residual, u)
     return nothing
 end
 
@@ -483,13 +485,13 @@ end
         resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, residual,
         mesh, cache, _, trait::DiffCacheNeeded, constraint
     ) where {BC1, BC2}
-    y_ = recursive_unflatten!(y, u)
-    resids = [get_tmp(r, u) for r in residual]
-    Φ!(resids[2:end], cache, y_, u, trait, constraint)
-    resida = resids[1][1:prod(cache.resid_size[1])]
-    residb = resids[1][(prod(cache.resid_size[1]) + 1):end]
+    y_ = recursive_unflatten!(y, cache.y_cache, u)
+    Φ!(residual[2:end], cache, y, u, trait, constraint)
+    resid0 = get_tmp(residual[1], u)
+    resida = resid0[1:prod(cache.resid_size[1])]
+    residb = resid0[(prod(cache.resid_size[1]) + 1):end]
     eval_bc_residual!((resida, residb), pt, bc!, y_, p, mesh)
-    recursive_flatten_twopoint!(resid, resids, cache.resid_size)
+    recursive_flatten_twopoint!(resid, residual, u, cache.resid_size)
     return nothing
 end
 
@@ -555,10 +557,10 @@ end
 @views function __mirk_loss_collocation!(
         resid, u, p, y, mesh, residual, cache, trait::DiffCacheNeeded, constraint
     )
-    y_ = recursive_unflatten!(y, u)
-    resids = [get_tmp(r, u) for r in residual[2:end]]
-    Φ!(resids, cache, y_, u, trait, constraint)
-    recursive_flatten!(resid, resids)
+    recursive_unflatten!(y, cache.y_cache, u)
+    collocation_residual = residual[2:end]
+    Φ!(collocation_residual, cache, y, u, trait, constraint)
+    recursive_flatten!(resid, collocation_residual, u)
     return nothing
 end
 
@@ -566,9 +568,9 @@ end
         resid, u, p, y, mesh, residual, cache, trait::NoDiffCacheNeeded, constraint
     )
     y_ = recursive_unflatten!(y, u)
-    resids = [r for r in residual[2:end]]
-    Φ!(resids, cache, y_, u, trait, constraint)
-    recursive_flatten!(resid, resids)
+    collocation_residual = residual[2:end]
+    Φ!(collocation_residual, cache, y_, u, trait, constraint)
+    recursive_flatten!(resid, collocation_residual)
     return nothing
 end
 
diff --git a/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl b/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl
index 2a92d2316..de9dc60de 100644
--- a/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl
+++ b/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl
@@ -385,10 +385,10 @@ end
     y_ = recursive_unflatten!(y, u)
     resids = [get_tmp(r, u) for r in residual]
     Φ!(resids[3:end], cache, y_, u, p)
-    EvalSol.u[1:end] .= __restructure_sol(y_[1:length(cache.mesh)], cache.in_size)
-    EvalSol.cache.k_discrete[1:end] .= cache.k_discrete
-    EvalDSol.u[1:end] .= __restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size)
-    EvalDSol.cache.k_discrete[1:end] .= cache.k_discrete
+    copyto!(EvalSol.u, __restructure_sol(y_[1:length(cache.mesh)], cache.in_size))
+    copyto!(EvalSol.cache.k_discrete, cache.k_discrete)
+    copyto!(EvalDSol.u, __restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size))
+    copyto!(EvalDSol.cache.k_discrete, cache.k_discrete)
     eval_bc_residual!(resids[1:2], pt, bc, EvalSol, EvalDSol, p, mesh)
     recursive_flatten!(resid, resids)
     return nothing
@@ -400,10 +400,10 @@ end
     ) where {BC}
     y_ = recursive_unflatten!(y, u)
     resid_co = Φ(cache, y_, u, p)
-    EvalSol.u[1:end] .= __restructure_sol(y_[1:length(cache.mesh)], cache.in_size)
-    EvalSol.cache.k_discrete[1:end] .= cache.k_discrete
-    EvalDSol.u[1:end] .= __restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size)
-    EvalDSol.cache.k_discrete[1:end] .= cache.k_discrete
+    copyto!(EvalSol.u, __restructure_sol(y_[1:length(cache.mesh)], cache.in_size))
+    copyto!(EvalSol.cache.k_discrete, cache.k_discrete)
+    copyto!(EvalDSol.u, __restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size))
+    copyto!(EvalDSol.cache.k_discrete, cache.k_discrete)
     resid_bc = eval_bc_residual(pt, bc, EvalSol, EvalDSol, p, mesh)
     return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
 end
diff --git a/test/misc/allocation_tests.jl b/test/misc/allocation_tests.jl
new file mode 100644
index 000000000..5162b4d3e
--- /dev/null
+++ b/test/misc/allocation_tests.jl
@@ -0,0 +1,81 @@
+@testitem "MIRK Loss Function Allocations" tags = [:allocs] begin
+    using BoundaryValueDiffEq, BoundaryValueDiffEqMIRK, BoundaryValueDiffEqCore, LinearAlgebra
+
+    function f!(du, u, p, t)
+        du[1] = u[2]
+        du[2] = -u[1]
+        return nothing
+    end
+
+    function bc!(resid, sol, p, t)
+        resid[1] = sol(0.0)[1] - 1.0
+        resid[2] = sol(1.0)[1] - cos(1.0)
+        return nothing
+    end
+
+    function tpbc_a!(resid, ua, p)
+        resid[1] = ua[1] - 1.0
+        return nothing
+    end
+
+    function tpbc_b!(resid, ub, p)
+        resid[1] = ub[1] - cos(1.0)
+        return nothing
+    end
+
+    u0 = [1.0, 0.0]
+    tspan = (0.0, 1.0)
+
+    bvp = BVProblem(BVPFunction{true}(f!, bc!; bcresid_prototype = zeros(2)), u0, tspan)
+    tpbvp = BVProblem(
+        BVPFunction{true}(
+            f!, (tpbc_a!, tpbc_b!);
+            bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true)
+        ),
+        u0, tspan
+    )
+
+    # Test that one loss function call stays within a fixed allocation budget
+    # (i.e., per-step allocations are bounded, not proportional to mesh points)
+    for (name, prob) in [("StandardBVP", bvp), ("TwoPointBVP", tpbvp)]
+        for alg in [MIRK4(), MIRK5(), MIRK6()]
+            cache = SciMLBase.__init(prob, alg; dt = 0.1, adaptive = false)
+            nlprob = BoundaryValueDiffEqMIRK.__construct_problem(
+                cache, vec(cache.y₀), copy(cache.y₀)
+            )
+
+            u_test = copy(nlprob.u0)
+            resid_test = zeros(length(nlprob.u0))
+
+            # Warmup
+            nlprob.f(resid_test, u_test, nlprob.p)
+
+            # Measure allocations per loss call
+            allocs = @allocated nlprob.f(resid_test, u_test, nlprob.p)
+
+            # Loss function should allocate less than 10 KiB per call
+            # (the remaining allocations are from SubArray views in the inner loop
+            # which scale with mesh size but are small per-element)
+            @test allocs < 10 * 1024  # 10 KiB threshold
+        end
+    end
+
+    # Test that non-adaptive solve allocations are bounded
+    for alg in [MIRK4(), MIRK5()]
+        # Small mesh
+        sol_small = solve(bvp, alg; dt = 0.1, adaptive = false)
+        @test sol_small.retcode == ReturnCode.Success
+        allocs_small = @allocated solve(bvp, alg; dt = 0.1, adaptive = false)
+
+        # Larger mesh (5x)
+        sol_large = solve(bvp, alg; dt = 0.02, adaptive = false)
+        @test sol_large.retcode == ReturnCode.Success
+        allocs_large = @allocated solve(bvp, alg; dt = 0.02, adaptive = false)
+
+        # Allocations should grow at most about linearly with the mesh (5x here)
+        # (ideally close to linear with mesh size due to Jacobian setup,
+        # but per-Newton-step allocations should be small)
+        ratio = allocs_large / allocs_small
+        @test ratio < 10  # Should be well under 10x for 5x more mesh points
+    end
+end
diff --git a/test/qa/runtests.jl b/test/qa/runtests.jl
index b67625142..d61d61bd8 100644
--- a/test/qa/runtests.jl
+++ b/test/qa/runtests.jl
@@ -1,4 +1,5 @@
-using BoundaryValueDiffEq, Aqua, JET, Test, SciMLBase
+using BoundaryValueDiffEq, BoundaryValueDiffEqMIRK, BoundaryValueDiffEqCore,
+    Aqua, JET, Test, SciMLBase
 
 @testset "Quality Assurance" begin
     @testset "Aqua" begin
@@ -15,4 +16,42 @@ using BoundaryValueDiffEq, Aqua, JET, Test, SciMLBase
         )
         @test length(JET.get_reports(rep)) == 0
     end
+
+    @testset "Zero per-step allocations in MIRK loss function" begin
+        function _f!(du, u, p, t)
+            du[1] = u[2]
+            du[2] = -u[1]
+            return nothing
+        end
+        function _bc!(resid, sol, p, t)
+            resid[1] = sol(0.0)[1] - 1.0
+            resid[2] = sol(1.0)[1] - cos(1.0)
+            return nothing
+        end
+        u0 = [1.0, 0.0]
+        tspan = (0.0, 1.0)
+        bvp = BVProblem(
+            BVPFunction{true}(_f!, _bc!; bcresid_prototype = zeros(2)), u0, tspan
+        )
+
+        cache = SciMLBase.__init(bvp, MIRK4(); dt = 0.1, adaptive = false)
+        nlprob = BoundaryValueDiffEqMIRK.__construct_problem(
+            cache, vec(cache.y₀), copy(cache.y₀)
+        )
+
+        u_test = copy(nlprob.u0)
+        resid_test = zeros(length(nlprob.u0))
+        p_test = nlprob.p
+
+        # Verify loss function is allocation-free at runtime
+        function _bench_loss(f, resid, u, p, N)
+            for _ in 1:N
+                f(resid, u, p)
+            end
+        end
+        _bench_loss(nlprob.f, resid_test, u_test, p_test, 10)  # warmup
+        stats = @timed _bench_loss(nlprob.f, resid_test, u_test, p_test, 10_000)
+        bytes_per_call = stats.bytes / 10_000
+        @test bytes_per_call == 0.0
+    end
 end