Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module RecursiveArrayToolsRaggedArrays

import RecursiveArrayTools: RecursiveArrayTools, AbstractRaggedVectorOfArray,
AbstractRaggedDiffEqArray, VectorOfArray, DiffEqArray,
AbstractVectorOfArray, AbstractDiffEqArray, AllObserved
AbstractVectorOfArray, AbstractDiffEqArray, AllObserved,
recursivefill!, recursivecopy!
using SymbolicIndexingInterface
using SymbolicIndexingInterface: ParameterTimeseriesCollection, ParameterIndexingProxy,
ScalarSymbolic, ArraySymbolic, NotSymbolic, Timeseries, SymbolCache
Expand Down Expand Up @@ -1519,6 +1520,26 @@ end

Base.map(f, A::AbstractRaggedVectorOfArray) = map(f, A.u)

# Named functor used by the nested-ragged mapreduce to ensure type-stable dispatch.
struct _RaggedMapReduce{F, Op}
f::F
op::Op
end
@inline (w::_RaggedMapReduce)(u) = mapreduce(w.f, w.op, u)

# When inner elements are themselves ragged, the view-based approach fails: view uses
# size(A.u[1]) for every column, causing BoundsErrors when inner shapes differ.
# We recurse element-by-element instead. Dispatching on the type of A.u (rather than
# using an if-check at runtime) keeps inference type-stable down to Julia 1.10.
function Base.mapreduce(
f, op,
A::AbstractRaggedVectorOfArray{T, N, <:AbstractVector{<:AbstractRaggedVectorOfArray}};
kwargs...
) where {T, N}
isempty(kwargs) || return mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...)
return mapreduce(_RaggedMapReduce(f, op), op, A.u)
end

function Base.mapreduce(f, op, A::AbstractRaggedVectorOfArray; kwargs...)
return mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...)
end
Expand Down Expand Up @@ -1725,4 +1746,39 @@ end
# Re-export has_discretes and get_discretes for the non-ragged types
has_discretes(::TT) where {TT <: AbstractDiffEqArray} = hasfield(TT, :discretes)

function recursivecopy!(b::AbstractRaggedVectorOfArray, a::AbstractRaggedVectorOfArray)
@inbounds for i in eachindex(b.u, a.u)
if ArrayInterface.ismutable(b.u[i]) || b.u[i] isa AbstractRaggedVectorOfArray
recursivecopy!(b.u[i], a.u[i])
else
b.u[i] = copy(a.u[i])
end
end
return b
end

function recursivefill!(
b::AbstractRaggedVectorOfArray{T, N},
a::T2
) where {T <: Union{Number, Bool}, T2 <: Union{Number, Bool}, N}
return fill!(b, a)
end

function recursivefill!(
b::AbstractRaggedVectorOfArray{T, N},
a::T2
) where {T <: StaticArraysCore.SArray, T2 <: Union{Number, Bool}, N}
@inbounds for arr in b.u, i in eachindex(arr)
arr[i] = map(_ -> a, arr[i])
end
return b
end

function recursivefill!(b::AbstractRaggedVectorOfArray{T, N}, a) where {T <: AbstractArray, N}
@inbounds for arr in b.u
recursivefill!(arr, a)
end
return b
end

end # module RecursiveArrayToolsRaggedArrays
41 changes: 41 additions & 0 deletions lib/RecursiveArrayToolsRaggedArrays/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -986,4 +986,45 @@ using Test
@test rd isa RecursiveArrayTools.AbstractRaggedVectorOfArray
@test !(rd isa AbstractArray)
end

@testset "recursivefill! for RaggedVectorOfArray" begin
# Bool argument — the pattern used by ODE solver cache initialisation
r = RaggedVectorOfArray([ones(2), ones(3)])
recursivefill!(r, false)
@test r[:, 1] == [0.0, 0.0]
@test r[:, 2] == [0.0, 0.0, 0.0]

# Numeric argument
r2 = RaggedVectorOfArray([zeros(2), zeros(3)])
recursivefill!(r2, 1.0)
@test r2[:, 1] == [1.0, 1.0]
@test r2[:, 2] == [1.0, 1.0, 1.0]

# Ragged sizes are preserved
@test length(r[:, 1]) == 2
@test length(r[:, 2]) == 3
end

@testset "recursivecopy! for RaggedVectorOfArray" begin
src = RaggedVectorOfArray([ones(2), 2 * ones(3)])
dst = RaggedVectorOfArray([zeros(2), zeros(3)])
recursivecopy!(dst, src)
@test dst[:, 1] == [1.0, 1.0]
@test dst[:, 2] == [2.0, 2.0, 2.0]

# Verify deep copy — modifying src must not affect dst
src[:, 1] .= 99.0
@test dst[:, 1] == [1.0, 1.0]
end

@testset "mapreduce over nested ragged arrays" begin
# Outer array whose inner RaggedVoA elements have different column counts.
# mapreduce must recurse over A.u rather than building a fixed-shape view.
inner1 = RaggedVectorOfArray([ones(3), ones(3)]) # 2 columns
inner2 = RaggedVectorOfArray([ones(3), ones(3), ones(3)]) # 3 columns — ragged!
u = RaggedVectorOfArray([inner1, inner2])

@test mapreduce(identity, +, u) == 15.0 # (2+3)*3
end

end
Loading