Skip to content

Commit 5bb9c3d

Browse files
Add recursivecopyto! with copyto!-style linear-index semantics
`recursivecopy!` requires `ndims(b) == ndims(a)` via dispatch, matching its `copy!` namesake — which errors on mismatched axes. `recursivecopyto!` is the `copyto!` counterpart: linear-index copy with no shape check, only requiring `length(b) >= length(a)`. This unblocks the `reinit(prob; u0 = vec)` use case from #589 without weakening `recursivecopy!`'s contract. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 87c7a04 commit 5bb9c3d

5 files changed

Lines changed: 166 additions & 2 deletions

File tree

docs/src/recursive_array_functions.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and do not require that the RecursiveArrayTools types are used.
88
```@docs
99
recursivecopy
1010
recursivecopy!
11+
recursivecopyto!
1112
vecvecapply
1213
copyat_or_push!
1314
```

src/RecursiveArrayTools.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ module RecursiveArrayTools
189189
export DEFAULT_PLOT_FUNC, plottable_indices, plot_indices, getindepsym_defaultt,
190190
interpret_vars, add_labels!, diffeq_to_arrays, solplot_vecs_and_labels
191191

192-
export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_push!,
192+
export recursivecopy, recursivecopy!, recursivecopyto!, recursivefill!, vecvecapply,
193+
copyat_or_push!,
193194
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
194195
recursive_unitless_bottom_eltype, recursive_unitless_eltype
195196

src/array_partition.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,23 @@ function recursivecopy!(
361361
return A
362362
end
363363

364+
function recursivecopyto!(A::ArrayPartition, B::ArrayPartition)
365+
for (a, b) in zip(A.x, B.x)
366+
recursivecopyto!(a, b)
367+
end
368+
return A
369+
end
370+
371+
function recursivecopyto!(
372+
A::ArrayPartition{T, S},
373+
B::ArrayPartition{T, S}
374+
) where {T, S <: Tuple{Vararg{AbstractVectorOfArray}}}
375+
for i in eachindex(A.x, B.x)
376+
recursivecopyto!(A.x[i], B.x[i])
377+
end
378+
return A
379+
end
380+
364381
function recursive_mean(A::ArrayPartition)
365382
n = npartitions(A)
366383
if n == 0

src/utils.jl

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ recursivecopy!(b::AbstractArray{T, N}, a::AbstractArray{T, N})
4545
```
4646
4747
A recursive `copy!` function. Acts like a `deepcopy!` on arrays of arrays, but
48-
like `copy!` on arrays of scalars.
48+
like `copy!` on arrays of scalars. Requires `b` and `a` to have matching `ndims`;
49+
use [`recursivecopyto!`](@ref) for the `copyto!`-style linear-index variant that
50+
allows mismatched shapes.
4951
"""
5052
function recursivecopy! end
5153

@@ -105,6 +107,68 @@ function recursivecopy!(b::AbstractVectorOfArray, a::AbstractVectorOfArray)
105107
return b
106108
end
107109

110+
"""
111+
```julia
112+
recursivecopyto!(b::AbstractArray, a::AbstractArray)
113+
```
114+
115+
A recursive `copyto!` function. Acts like a `deepcopy!` on arrays of arrays, but
116+
like `copyto!` on arrays of scalars.
117+
118+
Unlike [`recursivecopy!`](@ref), this does not require `b` and `a` to have matching
119+
`ndims` or axes; only that `length(b) >= length(a)`. Elements are copied in linear
120+
(column-major) order, matching the semantics of `Base.copyto!`. Use this when
121+
flattening/reshaping between destination and source is intended, e.g. copying a
122+
`Vector` into a `Matrix` of the same total length.
123+
"""
124+
function recursivecopyto! end
125+
126+
function recursivecopyto!(b::AbstractArray, a::AbstractArray)
127+
return copyto!(b, a)
128+
end
129+
130+
function recursivecopyto!(
131+
b::AbstractArray{T},
132+
a::AbstractArray{T2}
133+
) where {
134+
T <: StaticArraysCore.StaticArray,
135+
T2 <: StaticArraysCore.StaticArray,
136+
}
137+
@inbounds for (ib, ia) in zip(eachindex(b), eachindex(a))
138+
# TODO: Check for `setindex!`` and use `copy!(b[i],a[i])` or `b[i] = a[i]`, see #19
139+
b[ib] = copy(a[ia])
140+
end
141+
return b
142+
end
143+
144+
function recursivecopyto!(
145+
b::AbstractArray{T},
146+
a::AbstractArray{T2}
147+
) where {
148+
T <: Union{AbstractArray, AbstractVectorOfArray},
149+
T2 <: Union{AbstractArray, AbstractVectorOfArray},
150+
}
151+
if ArrayInterface.ismutable(T)
152+
@inbounds for (ib, ia) in zip(eachindex(b), eachindex(a))
153+
recursivecopyto!(b[ib], a[ia])
154+
end
155+
else
156+
copyto!(b, a)
157+
end
158+
return b
159+
end
160+
161+
function recursivecopyto!(b::AbstractVectorOfArray, a::AbstractVectorOfArray)
162+
@inbounds for i in eachindex(b.u, a.u)
163+
if ArrayInterface.ismutable(b.u[i]) || b.u[i] isa AbstractVectorOfArray
164+
recursivecopyto!(b.u[i], a.u[i])
165+
else
166+
b.u[i] = recursivecopy(a.u[i])
167+
end
168+
end
169+
return b
170+
end
171+
108172
"""
109173
```julia
110174
recursivefill!(b::AbstractArray{T, N}, a)

test/utils_test.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,87 @@ end
152152
@test a.u[1][1] == 1.0
153153
end
154154

155+
@testset "recursivecopyto!" begin
156+
# Same-shape scalar arrays — should match copyto!
157+
b = zeros(3)
158+
a = [1.0, 2.0, 3.0]
159+
recursivecopyto!(b, a)
160+
@test b == a
161+
162+
b = zeros(2, 2)
163+
a = [1.0 2.0; 3.0 4.0]
164+
recursivecopyto!(b, a)
165+
@test b == a
166+
167+
# Issue #589: Matrix ← Vector of matching length (rejected by recursivecopy!,
168+
# allowed by recursivecopyto!).
169+
b = zeros(2, 3)
170+
a = collect(1.0:6.0)
171+
recursivecopyto!(b, a)
172+
@test b == reshape(a, 2, 3)
173+
@test_throws MethodError recursivecopy!(b, a)
174+
175+
# Vector ← Matrix
176+
b = zeros(6)
177+
a = reshape(collect(1.0:6.0), 2, 3)
178+
recursivecopyto!(b, a)
179+
@test b == collect(1.0:6.0)
180+
181+
# Different-shape matrices, same total length
182+
b = zeros(2, 3)
183+
a = reshape(collect(1.0:6.0), 3, 2)
184+
recursivecopyto!(b, a)
185+
@test vec(b) == 1.0:6.0
186+
187+
# dst longer than src — tail untouched, matches Base.copyto!
188+
b = ones(5)
189+
a = [10.0, 20.0, 30.0]
190+
recursivecopyto!(b, a)
191+
@test b == [10.0, 20.0, 30.0, 1.0, 1.0]
192+
193+
# dst shorter than src — BoundsError, matches Base.copyto!
194+
b = zeros(2)
195+
a = [1.0, 2.0, 3.0]
196+
@test_throws BoundsError recursivecopyto!(b, a)
197+
198+
# Nested: Vector of Vectors, matching shapes
199+
a = [ones(3), 2 * ones(3)]
200+
b = [zeros(3), zeros(3)]
201+
recursivecopyto!(b, a)
202+
@test b[1] == ones(3) && b[2] == 2 * ones(3)
203+
# Verify deep copy semantics — mutating dst leaves src untouched
204+
b[1][1] = 99.0
205+
@test a[1][1] == 1.0
206+
207+
# Nested with shape mismatch at the leaves — inner copyto! handles it
208+
a = [collect(1.0:6.0), collect(7.0:12.0)]
209+
b = [zeros(2, 3), zeros(2, 3)]
210+
recursivecopyto!(b, a)
211+
@test b[1] == reshape(1.0:6.0, 2, 3)
212+
@test b[2] == reshape(7.0:12.0, 2, 3)
213+
214+
# Static array element
215+
a = [@SVector([1.0, 2.0]), @SVector([3.0, 4.0])]
216+
b = [@SVector(zeros(2)), @SVector(zeros(2))]
217+
recursivecopyto!(b, a)
218+
@test b == a
219+
220+
# ArrayPartition with matching shapes (sanity — parity with recursivecopy!)
221+
A = ArrayPartition(zeros(2), zeros(3))
222+
B = ArrayPartition([1.0, 2.0], [3.0, 4.0, 5.0])
223+
recursivecopyto!(A, B)
224+
@test A.x[1] == [1.0, 2.0]
225+
@test A.x[2] == [3.0, 4.0, 5.0]
226+
227+
# VectorOfArray
228+
u1 = VA[zeros(MVector{2, Float64}), zeros(MVector{2, Float64})]
229+
u2 = VA[fill(4, MVector{2, Float64}), 2 .* ones(MVector{2, Float64})]
230+
recursivecopyto!(u1, u2)
231+
@test u1.u[1] == [4.0, 4.0]
232+
@test u1.u[2] == [2.0, 2.0]
233+
@test u1.u[1] isa MVector
234+
end
235+
155236
@testset "VectorOfArray similar with nested scalar leaves" begin
156237
a = VA[ones(2), VA[1.0, 1.0]]
157238
b = similar(a, Float64)

0 commit comments

Comments
 (0)