Skip to content

Commit 484d3af

Browse files
Run Runic formatting
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4e28332 commit 484d3af

5 files changed

Lines changed: 111 additions & 59 deletions

File tree

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,20 @@ end
3838
y -> begin
3939
if y isa Ref
4040
y = VectorOfArray(y[].u)
41-
end
41+
end
4242
# Return a plain Vector of arrays as gradient for `u`, not wrapped in VectorOfArray.
4343
# This avoids issues with downstream pullbacks that index into the gradient
4444
# using linear indexing (which now returns scalar elements for VectorOfArray).
4545
if y isa AbstractVectorOfArray
4646
(y.u,)
47-
else
48-
([
47+
else
48+
(
49+
[
4950
y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
5051
for i in 1:size(y)[end]
51-
],)
52-
end
52+
],
53+
)
54+
end
5355
end
5456
end
5557

@@ -63,17 +65,18 @@ end
6365
y -> begin
6466
if y isa Ref
6567
y = VectorOfArray(y[].u)
66-
end
68+
end
6769
if y isa AbstractVectorOfArray
6870
(y.u, nothing)
69-
else
70-
([
71+
else
72+
(
73+
[
7174
y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
7275
for i in 1:size(y)[end]
7376
],
7477
nothing,
7578
)
76-
end
79+
end
7780
end
7881
end
7982

lib/RecursiveArrayToolsRaggedArrays/src/RecursiveArrayToolsRaggedArrays.jl

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,10 @@ end
103103
function RaggedVectorOfArray(vec::AbstractVector)
104104
T = eltype(vec[1])
105105
N = ndims(vec[1])
106-
if all(x -> x isa Union{<:AbstractArray, <:AbstractVectorOfArray, <:AbstractRaggedVectorOfArray},
107-
vec)
106+
if all(
107+
x -> x isa Union{<:AbstractArray, <:AbstractVectorOfArray, <:AbstractRaggedVectorOfArray},
108+
vec
109+
)
108110
A = Vector{Union{typeof.(vec)...}}
109111
else
110112
A = typeof(vec)
@@ -240,8 +242,10 @@ function VectorOfArray(r::AbstractRaggedVectorOfArray{T, N, A}) where {T, N, A}
240242
end
241243

242244
function DiffEqArray(r::AbstractRaggedDiffEqArray)
243-
return DiffEqArray(r.u, r.t, r.p, r.sys; discretes = r.discretes,
244-
interp = r.interp, dense = r.dense)
245+
return DiffEqArray(
246+
r.u, r.t, r.p, r.sys; discretes = r.discretes,
247+
interp = r.interp, dense = r.dense
248+
)
245249
end
246250

247251
# ═══════════════════════════════════════════════════════════════════════════════
@@ -409,8 +413,10 @@ function Base.getindex(r::RaggedVectorOfArray, ::Colon, ::Colon)
409413
end
410414

411415
function Base.getindex(r::RaggedDiffEqArray, ::Colon, ::Colon)
412-
return RaggedDiffEqArray(copy(r.u), copy(r.t), r.p, r.sys; discretes = r.discretes,
413-
interp = r.interp, dense = r.dense)
416+
return RaggedDiffEqArray(
417+
copy(r.u), copy(r.t), r.p, r.sys; discretes = r.discretes,
418+
interp = r.interp, dense = r.dense
419+
)
414420
end
415421

416422
# A[:, idx_array] returns a subset
@@ -425,8 +431,10 @@ function Base.getindex(
425431
r::RaggedDiffEqArray, ::Colon,
426432
I::Union{AbstractArray{Int}, AbstractArray{Bool}}
427433
)
428-
return RaggedDiffEqArray(r.u[I], r.t[I], r.p, r.sys; discretes = r.discretes,
429-
interp = r.interp, dense = r.dense)
434+
return RaggedDiffEqArray(
435+
r.u[I], r.t[I], r.p, r.sys; discretes = r.discretes,
436+
interp = r.interp, dense = r.dense
437+
)
430438
end
431439

432440
# A[j, :] returns a vector of the j-th component from each inner array
@@ -460,8 +468,10 @@ end
460468
function Base.getindex(
461469
r::RaggedDiffEqArray, ::Colon, I::AbstractRange
462470
)
463-
return RaggedDiffEqArray(r.u[I], r.t[I], r.p, r.sys;
464-
discretes = r.discretes, interp = r.interp, dense = r.dense)
471+
return RaggedDiffEqArray(
472+
r.u[I], r.t[I], r.p, r.sys;
473+
discretes = r.discretes, interp = r.interp, dense = r.dense
474+
)
465475
end
466476

467477
# ═══════════════════════════════════════════════════════════════════════════════
@@ -510,8 +520,10 @@ Base.@propagate_inbounds function _ragged_getindex(
510520
col_idxs = last(I) isa Colon ? eachindex(r.u) : last(I)
511521
if all(idx -> idx isa Colon, Base.front(I))
512522
u_slice = [r.u[col][Base.front(I)...] for col in col_idxs]
513-
return RaggedDiffEqArray(u_slice, r.t[col_idxs], r.p, r.sys;
514-
discretes = r.discretes, interp = r.interp, dense = r.dense)
523+
return RaggedDiffEqArray(
524+
u_slice, r.t[col_idxs], r.p, r.sys;
525+
discretes = r.discretes, interp = r.interp, dense = r.dense
526+
)
515527
else
516528
return [r.u[col][Base.front(I)...] for col in col_idxs]
517529
end
@@ -629,8 +641,10 @@ end
629641
function Base.zero(r::AbstractRaggedVectorOfArray)
630642
T = typeof(r)
631643
u_zero = [zero(u) for u in r.u]
632-
fields = [fname == :u ? u_zero : _ragged_copyfield(r, fname)
633-
for fname in fieldnames(T)]
644+
fields = [
645+
fname == :u ? u_zero : _ragged_copyfield(r, fname)
646+
for fname in fieldnames(T)
647+
]
634648
return T(fields...)
635649
end
636650

@@ -671,8 +685,10 @@ Base.sizehint!(r::AbstractRaggedVectorOfArray, i) = sizehint!(r.u, i)
671685
Base.reverse!(r::AbstractRaggedVectorOfArray) = (reverse!(r.u); r)
672686
Base.reverse(r::RaggedVectorOfArray) = RaggedVectorOfArray(reverse(r.u))
673687
function Base.reverse(r::RaggedDiffEqArray)
674-
return RaggedDiffEqArray(reverse(r.u), r.t, r.p, r.sys; discretes = r.discretes,
675-
interp = r.interp, dense = r.dense)
688+
return RaggedDiffEqArray(
689+
reverse(r.u), r.t, r.p, r.sys; discretes = r.discretes,
690+
interp = r.interp, dense = r.dense
691+
)
676692
end
677693

678694
function Base.copyto!(
@@ -732,8 +748,12 @@ _ragged_narrays(bc::Broadcast.Broadcasted) = __ragged_narrays(bc.args)
732748
function __ragged_narrays(args::Tuple)
733749
a = _ragged_narrays(args[1])
734750
b = __ragged_narrays(Base.tail(args))
735-
return a == 0 ? b : (b == 0 ? a : (a == b ? a :
736-
throw(DimensionMismatch("number of arrays must be equal"))))
751+
return a == 0 ? b : (
752+
b == 0 ? a : (
753+
a == b ? a :
754+
throw(DimensionMismatch("number of arrays must be equal"))
755+
)
756+
)
737757
end
738758
__ragged_narrays(args::Tuple{Any}) = _ragged_narrays(args[1])
739759
__ragged_narrays(::Tuple{}) = 0
@@ -812,7 +832,7 @@ end
812832

813833
function Base.show(io::IO, m::MIME"text/plain", r::AbstractRaggedVectorOfArray)
814834
println(io, summary(r), ':')
815-
show(io, m, r.u)
835+
return show(io, m, r.u)
816836
end
817837

818838
function Base.summary(r::AbstractRaggedVectorOfArray{T, N}) where {T, N}
@@ -824,7 +844,7 @@ function Base.show(io::IO, m::MIME"text/plain", r::AbstractRaggedDiffEqArray)
824844
show(io, m, r.t)
825845
println(io)
826846
print(io, "u: ")
827-
show(io, m, r.u)
847+
return show(io, m, r.u)
828848
end
829849

830850
function Base.summary(r::AbstractRaggedDiffEqArray{T, N}) where {T, N}
@@ -835,8 +855,10 @@ end
835855
# Callable interface — dense interpolation
836856
# ═══════════════════════════════════════════════════════════════════════════════
837857

838-
function (r::AbstractRaggedDiffEqArray)(t, ::Type{deriv} = Val{0};
839-
idxs = nothing, continuity = :left) where {deriv}
858+
function (r::AbstractRaggedDiffEqArray)(
859+
t, ::Type{deriv} = Val{0};
860+
idxs = nothing, continuity = :left
861+
) where {deriv}
840862
r.interp === nothing &&
841863
error("No interpolation data is available. Provide an interpolation object via the `interp` keyword.")
842864
return r.interp(t, idxs, deriv, r.p, continuity)

lib/RecursiveArrayToolsRaggedArrays/test/runtests.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,10 @@ using Test
320320
@test r.dense == false
321321

322322
# With interp kwarg
323-
r2 = RaggedDiffEqArray([[1.0, 2.0], [3.0, 4.0, 5.0]], [0.0, 1.0];
324-
interp = :test_interp, dense = true)
323+
r2 = RaggedDiffEqArray(
324+
[[1.0, 2.0], [3.0, 4.0, 5.0]], [0.0, 1.0];
325+
interp = :test_interp, dense = true
326+
)
325327
@test r2.interp == :test_interp
326328
@test r2.dense == true
327329

src/vector_of_array.jl

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,10 @@ function DiffEqArray(
449449
end
450450

451451
# first element representative
452-
function DiffEqArray(vec::AbstractVector, ts::AbstractVector, p, sys;
453-
discretes = nothing, interp = nothing, dense = false)
452+
function DiffEqArray(
453+
vec::AbstractVector, ts::AbstractVector, p, sys;
454+
discretes = nothing, interp = nothing, dense = false
455+
)
454456
_size = size(vec[1])
455457
T = eltype(vec[1])
456458
return DiffEqArray{
@@ -548,8 +550,10 @@ end
548550
# SciMLBase's more-specific `(::AbstractODESolution)(t,...)` methods win dispatch
549551
# for solution objects and handle symbolic idxs, discrete params, etc.
550552

551-
function (da::AbstractDiffEqArray)(t, ::Type{deriv} = Val{0};
552-
idxs = nothing, continuity = :left) where {deriv}
553+
function (da::AbstractDiffEqArray)(
554+
t, ::Type{deriv} = Val{0};
555+
idxs = nothing, continuity = :left
556+
) where {deriv}
553557
da.interp === nothing &&
554558
error("No interpolation data is available. Provide an interpolation object via the `interp` keyword.")
555559
return da.interp(t, idxs, deriv, da.p, continuity)
@@ -775,9 +779,11 @@ Base.@propagate_inbounds function Base.setindex!(
775779
for d in 1:length(inner_I)
776780
if inner_I[d] > size(u_col, d)
777781
iszero(x) && return x
778-
throw(ArgumentError(
779-
"Cannot set non-zero value at index $ii: outside ragged storage bounds."
780-
))
782+
throw(
783+
ArgumentError(
784+
"Cannot set non-zero value at index $ii: outside ragged storage bounds."
785+
)
786+
)
781787
end
782788
end
783789
return u_col[CartesianIndex(inner_I)] = x
@@ -887,10 +893,12 @@ Base.@propagate_inbounds function Base.setindex!(
887893
for d in 1:length(inner_I)
888894
if inner_I[d] > size(u_col, d)
889895
iszero(v) && return v
890-
throw(ArgumentError(
891-
"Cannot set non-zero value at index $I: outside ragged storage bounds. " *
892-
"Inner array $col has size $(size(u_col)) but index requires $(inner_I)."
893-
))
896+
throw(
897+
ArgumentError(
898+
"Cannot set non-zero value at index $I: outside ragged storage bounds. " *
899+
"Inner array $col has size $(size(u_col)) but index requires $(inner_I)."
900+
)
901+
)
894902
end
895903
end
896904
return u_col[inner_I...] = v
@@ -906,7 +914,7 @@ Base.@propagate_inbounds function Base.getindex(
906914
inner_I = Base.front(I)
907915
u_col = A.u[col]
908916
# Return zero for indices outside ragged storage
909-
for d in 1:N - 1
917+
for d in 1:(N - 1)
910918
if inner_I[d] > size(u_col, d)
911919
return zero(T)
912920
end
@@ -1089,8 +1097,10 @@ end
10891097
return Array{T}(undef, dims...)
10901098
end
10911099
@inline function Base.similar(
1092-
VA::AbstractVectorOfArray, ::Type{T}, dims::Tuple{Union{Integer, Base.OneTo},
1093-
Vararg{Union{Integer, Base.OneTo}}}
1100+
VA::AbstractVectorOfArray, ::Type{T}, dims::Tuple{
1101+
Union{Integer, Base.OneTo},
1102+
Vararg{Union{Integer, Base.OneTo}},
1103+
}
10941104
) where {T}
10951105
return similar(Array{T}, dims)
10961106
end
@@ -1361,12 +1371,16 @@ function solplot_vecs_and_labels(dims, vars, plott, A)
13611371
return plot_vecs, labels
13621372
end
13631373

1364-
@recipe function f(VA::AbstractDiffEqArray;
1365-
denseplot = (hasproperty(VA, :dense) && VA.dense &&
1366-
hasproperty(VA, :interp) && VA.interp !== nothing),
1374+
@recipe function f(
1375+
VA::AbstractDiffEqArray;
1376+
denseplot = (
1377+
hasproperty(VA, :dense) && VA.dense &&
1378+
hasproperty(VA, :interp) && VA.interp !== nothing
1379+
),
13671380
plotdensity = max(1000, 10 * length(VA.u)),
13681381
tspan = nothing, plotat = nothing,
1369-
idxs = nothing)
1382+
idxs = nothing
1383+
)
13701384

13711385
idxs_input = idxs === nothing ? plottable_indices(VA.u[1]) : idxs
13721386
if !(idxs_input isa Union{Tuple, AbstractArray})
@@ -1397,8 +1411,10 @@ end
13971411
end
13981412

13991413
# Default xguide for time-vs-variable plots
1400-
if all(x -> (x[2] isa Integer && x[2] == 0) ||
1401-
isequal(x[2], getindepsym_defaultt(VA)), vars)
1414+
if all(
1415+
x -> (x[2] isa Integer && x[2] == 0) ||
1416+
isequal(x[2], getindepsym_defaultt(VA)), vars
1417+
)
14021418
xguide --> "$(getindepsym_defaultt(VA))"
14031419
if tspan === nothing
14041420
if tdir > 0

test/runtests.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
using Pkg
22
# Install the ShorthandConstructors subpackage for tests that need VA[...]/AP[...] syntax
3-
Pkg.develop(PackageSpec(
4-
path = joinpath(dirname(@__DIR__), "lib", "RecursiveArrayToolsShorthandConstructors")))
3+
Pkg.develop(
4+
PackageSpec(
5+
path = joinpath(dirname(@__DIR__), "lib", "RecursiveArrayToolsShorthandConstructors")
6+
)
7+
)
58
using RecursiveArrayTools
69
using Test
710
using SafeTestsets
@@ -17,8 +20,11 @@ end
1720
function activate_gpu_env()
1821
Pkg.activate("gpu")
1922
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
20-
Pkg.develop(PackageSpec(
21-
path = joinpath(dirname(@__DIR__), "lib", "RecursiveArrayToolsArrayPartitionAnyAll")))
23+
Pkg.develop(
24+
PackageSpec(
25+
path = joinpath(dirname(@__DIR__), "lib", "RecursiveArrayToolsArrayPartitionAnyAll")
26+
)
27+
)
2228
return Pkg.instantiate()
2329
end
2430

@@ -49,8 +55,11 @@ end
4955

5056
if GROUP == "Subpackages" || GROUP == "All"
5157
# Test that loading RecursiveArrayToolsArrayPartitionAnyAll overrides any/all
52-
Pkg.develop(PackageSpec(
53-
path = joinpath(dirname(@__DIR__), "lib", "RecursiveArrayToolsArrayPartitionAnyAll")))
58+
Pkg.develop(
59+
PackageSpec(
60+
path = joinpath(dirname(@__DIR__), "lib", "RecursiveArrayToolsArrayPartitionAnyAll")
61+
)
62+
)
5463
@time @safetestset "ArrayPartition AnyAll Subpackage" begin
5564
using RecursiveArrayTools, RecursiveArrayToolsArrayPartitionAnyAll, Test
5665
# Verify optimized methods are active

0 commit comments

Comments
 (0)