Skip to content

Commit 8485bda

Browse files
Make AbstractVectorOfArray <: AbstractArray for proper interface compliance
This is a breaking change that completes the long-planned migration (documented since 2023) to make AbstractVectorOfArray a proper AbstractArray subtype. Key changes: - AbstractVectorOfArray{T,N,A} <: AbstractArray{T,N} - Linear indexing A[i] now returns the i-th element in column-major order (previously returned A.u[i], the i-th inner array) - size() uses maximum sizes across inner arrays for ragged data - Ragged out-of-bounds elements are treated as zero (sparse interpretation) - Iteration goes over scalar elements (AbstractArray default) - Removed deprecated linear indexing methods - Updated Zygote extension (some adjoints marked broken pending update) - Added parameter_values(::AbstractDiffEqArray, i) to fix dispatch ambiguity Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f7725f9 commit 8485bda

File tree

10 files changed

+450
-551
lines changed

10 files changed

+450
-551
lines changed

docs/pages.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
pages = [
44
"Home" => "index.md",
5+
"breaking_changes_v4.md",
56
"AbstractVectorOfArrayInterface.md",
67
"array_types.md",
78
"recursive_array_functions.md",

docs/src/breaking_changes_v4.md

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Breaking Changes in v4.0: AbstractArray Interface
2+
3+
## Summary
4+
5+
`AbstractVectorOfArray{T, N, A}` now subtypes `AbstractArray{T, N}`. This means
6+
all `VectorOfArray` and `DiffEqArray` objects are proper Julia `AbstractArray`s,
7+
and all standard `AbstractArray` operations work out of the box, including linear
8+
algebra, broadcasting with plain arrays, and generic algorithms.
9+
10+
## Key Changes
11+
12+
### Linear Indexing
13+
14+
Previously, `A[i]` returned the `i`th inner array (`A.u[i]`). Now, `A[i]` returns
15+
the `i`th element in column-major linear order, matching standard Julia `AbstractArray`
16+
behavior.
17+
18+
```julia
19+
A = VectorOfArray([[1, 2], [3, 4]])
20+
# Old: A[1] == [1, 2] (first inner array)
21+
# New: A[1] == 1 (first element, column-major)
22+
# To access inner arrays: A.u[1] or A[:, 1]
23+
```
24+
25+
### Size and Ragged Arrays
26+
27+
For ragged arrays (inner arrays of different sizes), `size(A)` now reports the
28+
**maximum** size in each dimension. Out-of-bounds elements are treated as zero
29+
(sparse representation):
30+
31+
```julia
32+
A = VectorOfArray([[1, 2], [3, 4, 5]])
33+
size(A) # (3, 2) — max inner length is 3
34+
A[3, 1] # 0 — implicit zero (inner array 1 has only 2 elements)
35+
A[3, 2] # 5 — actual stored value
36+
Array(A) # [1 3; 2 4; 0 5] — zero-padded dense array
37+
```
38+
39+
This means ragged `VectorOfArray`s can be used directly with linear algebra
40+
operations, treating the data as a rectangular matrix with zero padding.
41+
42+
### Iteration
43+
44+
Iteration now goes over scalar elements in column-major order, matching
45+
`AbstractArray` behavior:
46+
47+
```julia
48+
A = VectorOfArray([[1, 2], [3, 4]])
49+
collect(A) # [1 3; 2 4] — 2x2 matrix
50+
# To iterate over inner arrays: for u in A.u ... end
51+
```
52+
53+
### `length`
54+
55+
`length(A)` now returns `prod(size(A))` (total number of elements including
56+
ragged zeros), not the number of inner arrays. Use `length(A.u)` for the number
57+
of inner arrays.
58+
59+
### `map`
60+
61+
`map(f, A)` now maps over individual elements, not inner arrays. Use
62+
`map(f, A.u)` to map over inner arrays.
63+
64+
### `first` / `last`
65+
66+
`first(A)` and `last(A)` return the first/last scalar element, not the first/last
67+
inner array. Use `first(A.u)` / `last(A.u)` for inner arrays.
68+
69+
### `eachindex`
70+
71+
`eachindex(A)` returns `CartesianIndices(size(A))` for the full rectangular shape,
72+
not indices into `A.u`.
73+
74+
## Migration Guide
75+
76+
| Old Code | New Code |
77+
|----------|----------|
78+
| `A[i]` (get inner array) | `A.u[i]` or `A[:, i]` |
79+
| `length(A)` (number of arrays) | `length(A.u)` |
80+
| `for elem in A` (iterate columns) | `for elem in A.u` |
81+
| `first(A)` (first inner array) | `first(A.u)` |
82+
| `map(f, A)` (map over columns) | `map(f, A.u)` |
83+
| `A == vec_of_vecs` | `A.u == vec_of_vecs` |
84+
85+
## Zygote Compatibility
86+
87+
Some Zygote adjoint rules need updating for the new `AbstractArray` subtyping.
88+
ForwardDiff continues to work correctly. Zygote support will be updated in a
89+
follow-up release.

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 3 additions & 202 deletions
Original file line numberDiff line numberDiff line change
@@ -188,219 +188,20 @@ end
188188
view(A, I...), view_adjoint
189189
end
190190

191-
@adjoint function Broadcast.broadcasted(
192-
::typeof(+), x::AbstractVectorOfArray,
193-
y::Union{Zygote.Numeric, AbstractVectorOfArray}
194-
)
195-
broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)
196-
end
197-
@adjoint function Broadcast.broadcasted(
198-
::typeof(+), x::Zygote.Numeric, y::AbstractVectorOfArray
199-
)
200-
broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)
201-
end
191+
# Since AbstractVectorOfArray <: AbstractArray, Zygote's built-in AbstractArray
192+
# broadcast rules apply. We only keep specific overrides that don't conflict.
202193

203194
_minus(Δ) = .-Δ
204195
_minus(::Nothing) = nothing
205196

206-
@adjoint function Broadcast.broadcasted(
207-
::typeof(-), x::AbstractVectorOfArray,
208-
y::Union{AbstractVectorOfArray, Zygote.Numeric}
209-
)
210-
x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ)))
211-
end
212-
@adjoint function Broadcast.broadcasted(
213-
::typeof(*), x::AbstractVectorOfArray,
214-
y::Union{AbstractVectorOfArray, Zygote.Numeric}
215-
)
216-
(
217-
x .* y,
218-
Δ -> (
219-
nothing, Zygote.unbroadcast(x, Δ .* conj.(y)),
220-
Zygote.unbroadcast(y, Δ .* conj.(x)),
221-
),
222-
)
223-
end
224-
@adjoint function Broadcast.broadcasted(
225-
::typeof(/), x::AbstractVectorOfArray,
226-
y::Union{AbstractVectorOfArray, Zygote.Numeric}
227-
)
228-
res = x ./ y
229-
res,
230-
Δ -> (
231-
nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)),
232-
Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y)),
233-
)
234-
end
235-
@adjoint function Broadcast.broadcasted(
236-
::typeof(-), x::Zygote.Numeric, y::AbstractVectorOfArray
237-
)
238-
x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ)))
239-
end
240-
@adjoint function Broadcast.broadcasted(
241-
::typeof(*), x::Zygote.Numeric, y::AbstractVectorOfArray
242-
)
243-
(
244-
x .* y,
245-
Δ -> (
246-
nothing, Zygote.unbroadcast(x, Δ .* conj.(y)),
247-
Zygote.unbroadcast(y, Δ .* conj.(x)),
248-
),
249-
)
250-
end
251-
@adjoint function Broadcast.broadcasted(
252-
::typeof(/), x::Zygote.Numeric, y::AbstractVectorOfArray
253-
)
254-
res = x ./ y
255-
res,
256-
Δ -> (
257-
nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)),
258-
Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y)),
259-
)
260-
end
261-
@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray)
262-
.-x, Δ -> (nothing, _minus(Δ))
263-
end
264-
265-
@adjoint function Broadcast.broadcasted(
266-
::typeof(Base.literal_pow), ::typeof(^),
267-
x::AbstractVectorOfArray, exp::Val{p}
268-
) where {p}
269-
y = Base.literal_pow.(^, x, exp)
270-
y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing)
271-
end
272-
273-
@adjoint Broadcast.broadcasted(::typeof(identity), x::AbstractVectorOfArray) = x,
274-
Δ -> (nothing, Δ)
275-
276-
@adjoint function Broadcast.broadcasted(::typeof(tanh), x::AbstractVectorOfArray)
277-
y = tanh.(x)
278-
y, ȳ -> (nothing, ȳ .* conj.(1 .- y .^ 2))
279-
end
280-
281-
@adjoint Broadcast.broadcasted(::typeof(conj), x::AbstractVectorOfArray) = conj.(x),
282-
-> (nothing, conj.(z̄))
283-
284-
@adjoint Broadcast.broadcasted(::typeof(real), x::AbstractVectorOfArray) = real.(x),
285-
-> (nothing, real.(z̄))
286-
287-
@adjoint Broadcast.broadcasted(
288-
::typeof(imag), x::AbstractVectorOfArray
289-
) = imag.(x),
290-
-> (nothing, im .* real.(z̄))
291-
292-
@adjoint Broadcast.broadcasted(
293-
::typeof(abs2),
294-
x::AbstractVectorOfArray
295-
) = abs2.(x),
296-
-> (nothing, 2 .* real.(z̄) .* x)
297-
298-
@adjoint function Broadcast.broadcasted(
299-
::typeof(+), a::AbstractVectorOfArray{<:Number}, b::Bool
300-
)
301-
y = b === false ? a : a .+ b
302-
y, Δ -> (nothing, Δ, nothing)
303-
end
304-
@adjoint function Broadcast.broadcasted(
305-
::typeof(+), b::Bool, a::AbstractVectorOfArray{<:Number}
306-
)
307-
y = b === false ? a : b .+ a
308-
y, Δ -> (nothing, nothing, Δ)
309-
end
310-
311-
@adjoint function Broadcast.broadcasted(
312-
::typeof(-), a::AbstractVectorOfArray{<:Number}, b::Bool
313-
)
314-
y = b === false ? a : a .- b
315-
y, Δ -> (nothing, Δ, nothing)
316-
end
317-
@adjoint function Broadcast.broadcasted(
318-
::typeof(-), b::Bool, a::AbstractVectorOfArray{<:Number}
319-
)
320-
b .- a, Δ -> (nothing, nothing, .-Δ)
321-
end
322-
323-
@adjoint function Broadcast.broadcasted(
324-
::typeof(*), a::AbstractVectorOfArray{<:Number}, b::Bool
325-
)
326-
if b === false
327-
zero(a), Δ -> (nothing, zero(Δ), nothing)
328-
else
329-
a, Δ -> (nothing, Δ, nothing)
330-
end
331-
end
332-
@adjoint function Broadcast.broadcasted(
333-
::typeof(*), b::Bool, a::AbstractVectorOfArray{<:Number}
334-
)
335-
if b === false
336-
zero(a), Δ -> (nothing, nothing, zero(Δ))
337-
else
338-
a, Δ -> (nothing, nothing, Δ)
339-
end
340-
end
341-
342-
@adjoint Broadcast.broadcasted(
343-
::Type{T},
344-
x::AbstractVectorOfArray
345-
) where {
346-
T <:
347-
Number,
348-
} = T.(x),
349-
ȳ -> (nothing, Zygote._project(x, ȳ))
350-
351197
function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄)
352198
N = ndims(x̄)
353199
return if length(x) == length(x̄)
354-
Zygote._project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
200+
Zygote._project(x, x̄)
355201
else
356202
dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄) + 1, ndims(x̄))
357203
Zygote._project(x, Zygote.accum_sum(x̄; dims = dims))
358204
end
359205
end
360206

361-
@adjoint Broadcast.broadcasted(
362-
::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray,
363-
b
364-
) where {F} = _broadcast_generic(
365-
__context__, f, a, b
366-
)
367-
@adjoint Broadcast.broadcasted(
368-
::Broadcast.AbstractArrayStyle, f::F, a,
369-
b::AbstractVectorOfArray
370-
) where {F} = _broadcast_generic(
371-
__context__, f, a, b
372-
)
373-
@adjoint Broadcast.broadcasted(
374-
::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray,
375-
b::AbstractVectorOfArray
376-
) where {F} = _broadcast_generic(
377-
__context__, f, a, b
378-
)
379-
380-
@inline function _broadcast_generic(__context__, f::F, args...) where {F}
381-
T = Broadcast.combine_eltypes(f, args)
382-
# Avoid generic broadcasting in two easy cases:
383-
if T == Bool
384-
return (f.(args...), _ -> nothing)
385-
elseif T <: Union{Real, Complex} && isconcretetype(T) && Zygote._dual_purefun(F) &&
386-
all(Zygote._dual_safearg, args) && !Zygote.isderiving()
387-
return Zygote.broadcast_forward(f, args...)
388-
end
389-
len = Zygote.inclen(args)
390-
y∂b = Zygote._broadcast((x...) -> Zygote._pullback(__context__, f, x...), args...)
391-
y = broadcast(first, y∂b)
392-
function ∇broadcasted(ȳ)
393-
y∂b = y∂b isa AbstractVectorOfArray ? Iterators.flatten(y∂b.u) : y∂b
394-
ȳ = ȳ isa AbstractVectorOfArray ? Iterators.flatten.u) : ȳ
395-
dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
396-
getters = ntuple(i -> Zygote.StaticGetter{i}(), len)
397-
dxs = map(g -> Zygote.collapse_nothings(map(g, dxs_zip)), getters)
398-
return (
399-
nothing, Zygote.accum_sum(dxs[1]),
400-
map(Zygote.unbroadcast, args, Base.tail(dxs))...,
401-
)
402-
end
403-
return y, ∇broadcasted
404-
end
405-
406207
end # module

src/RecursiveArrayTools.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,13 @@ module RecursiveArrayTools
2626
2727
!!! note
2828
29-
In 2023 the linear indexing `A[i]` was deprecated. It previously had the behavior that `A[i] = A.u[i]`. However, this is incompatible with standard `AbstractArray` interfaces, Since if `A = VectorOfArray([[1,2],[3,4]])` and `A` is supposed to act like `[1 3; 2 4]`, then there is a difference `A[1] = [1,2]` for the VectorOfArray while `A[1] = 1` for the matrix. This causes many issues if `AbstractVectorOfArray <: AbstractArray`. Thus we plan in 2026 to complete the deprecation and thus have a breaking update where `A[i]` matches the linear indexing of an`AbstractArray`, and then making `AbstractVectorOfArray <: AbstractArray`. Until then, `AbstractVectorOfArray` due to
30-
this interface break but manually implements an `AbstractArray`-like interface for
31-
future compatibility.
29+
As of v4.0, `AbstractVectorOfArray <: AbstractArray`. Linear indexing `A[i]`
30+
now returns the `i`th element in column-major order, matching standard Julia
31+
`AbstractArray` behavior. To access the `i`th inner array, use `A.u[i]` or
32+
`A[:, i]`. For ragged arrays (inner arrays of different sizes), `size(A)`
33+
reports the maximum size in each dimension and out-of-bounds elements are
34+
interpreted as zero (sparse representation). This means all standard linear
35+
algebra operations work out of the box.
3236
3337
## Fields
3438
@@ -99,7 +103,7 @@ module RecursiveArrayTools
99103
to make the array type match the internal array type (for example, if `A` is an array
100104
of GPU arrays, `stack(A)` will be a GPU array).
101105
"""
102-
abstract type AbstractVectorOfArray{T, N, A} end
106+
abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
103107

104108
"""
105109
AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A}

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ function recursivefill!(
133133
T <: StaticArraysCore.StaticArray,
134134
T2 <: StaticArraysCore.StaticArray, N,
135135
}
136-
return @inbounds for b in bs, i in eachindex(b)
136+
return @inbounds for b in bs.u, i in eachindex(b)
137137

138138
b[i] = copy(a)
139139
end
@@ -159,7 +159,7 @@ function recursivefill!(
159159
T <: StaticArraysCore.SArray,
160160
T2 <: Union{Number, Bool}, N,
161161
}
162-
return @inbounds for b in bs, i in eachindex(b)
162+
return @inbounds for b in bs.u, i in eachindex(b)
163163

164164
# Preserve static array shape while replacing all entries with the scalar
165165
b[i] = map(_ -> a, b[i])

0 commit comments

Comments
 (0)