Skip to content

Commit f7725f9

Browse files
Merge pull request #546 from rokke-git/vectors
adds vector syntax for `ArrayPartition`s and `VectorOfArray`s
2 parents 19e9c4c + 334bd95 commit f7725f9

12 files changed

+193
-129
lines changed

README.md

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ vA = VectorOfArray(a)
3232
vB = VectorOfArray(b)
3333

3434
vA .* vB # Now all standard array stuff works!
35+
36+
# you can also create it directly with a vector-like syntax:
37+
c = VA[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
38+
d = VA[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
39+
40+
c .* d
3541
```
3642

3743
### ArrayPartition
@@ -44,15 +50,15 @@ pB = ArrayPartition(b)
4450

4551
pA .* pB # Now all standard array stuff works!
4652

47-
# or do:
53+
# or using the vector syntax:
4854
x0 = rand(3, 3)
4955
v0 = rand(3, 3)
5056
a0 = rand(3, 3)
51-
u0 = ArrayPartition(x0, v0, a0)
52-
u0.x[1] == x0 # true
57+
u0 = AP[x0, v0, a0]
58+
u0.x[1] === x0 # true
5359

5460
u0 .+= 1
55-
u0.x[2] == v0 # still true
61+
u0.x[2] === v0 # still true
5662

5763
# do some calculations creating a new partitioned array
5864
unew = u0 * 10

src/RecursiveArrayTools.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,14 @@ module RecursiveArrayTools
133133
Base.convert(T::Type{<:GPUArraysCore.AnyGPUArray}, VA::AbstractVectorOfArray) = stack(VA.u)
134134
(T::Type{<:GPUArraysCore.AnyGPUArray})(VA::AbstractVectorOfArray) = T(Array(VA))
135135

136-
export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
136+
export VectorOfArray, VA, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
137137
AllObserved, vecarr_to_vectors, tuples
138138

139139
export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_push!,
140140
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
141141
recursive_unitless_bottom_eltype, recursive_unitless_eltype
142142

143-
export ArrayPartition, NamedArrayPartition
143+
export ArrayPartition, AP, NamedArrayPartition
144144

145145
include("precompilation.jl")
146146

src/array_partition.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,3 +722,28 @@ end
722722
function Adapt.adapt_structure(to, ap::ArrayPartition)
723723
return ArrayPartition(map(x -> Adapt.adapt(to, x), ap.x)...)
724724
end
725+
726+
"""
727+
```julia
728+
AP[ matrices, ]
729+
```
730+
731+
Create an `ArrayPartition` using vector syntax. Equivalent to `ArrayPartition(matrices)`, but looks nicer with nesting.
732+
733+
# Examples:
734+
735+
Simple examples:
736+
```julia
737+
ArrayPartition([1,2,3], [1 2;3 4]) == AP[[1,2,3], [1 2;3 4]] # true
738+
AP[1u"m/s^2", 1u"m/s", 1u"m"]
739+
```
740+
741+
With an ODEProblem:
742+
```julia
743+
func(u, p, t) = AP[5u.x[1], u.x[2]./2]
744+
ODEProblem(func, AP[ [1.,2.,3.], [1. 2.;3. 4.] ], (0, 1)) |> solve
745+
```
746+
747+
"""
748+
struct AP end
749+
Base.getindex(::Type{AP}, xs...) = ArrayPartition(xs...)

src/vector_of_array.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1655,3 +1655,31 @@ end
16551655
end
16561656
unpack_args_voa(i, args::Tuple{Any}) = (unpack_voa(args[1], i),)
16571657
unpack_args_voa(::Any, args::Tuple{}) = ()
1658+
1659+
"""
1660+
```julia
1661+
VA[ matrices, ]
1662+
```
1663+
1664+
Create an `VectorOfArray` using vector syntax. Equivalent to `VectorOfArray([matrices])`, but looks nicer with nesting.
1665+
1666+
# Simple example:
1667+
```julia
1668+
VectorOfArray([[1,2,3], [1 2;3 4]]) == VA[[1,2,3], [1 2;3 4]] # true
1669+
```
1670+
1671+
# All the layers:
1672+
```julia
1673+
nested = VA[
1674+
fill(1, 2, 3),
1675+
VA[
1676+
VA[8, [1, 2, 3], [1 2;3 4], VA[1, 2, 3]],
1677+
fill(2, 3, 4),
1678+
VA[3ones(3), zeros(3)],
1679+
],
1680+
]
1681+
```
1682+
1683+
"""
1684+
struct VA end
1685+
Base.getindex(::Type{VA}, xs...) = VectorOfArray(collect(xs))

test/basic_indexing.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ testa = cat(recs..., dims = 2)
66
testva = VectorOfArray(recs)
77
@test maximum(testva) == maximum(maximum.(recs))
88

9+
testva = VA[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
10+
@test maximum(testva) == maximum(maximum.(recs))
11+
912
# broadcast with array
1013
X = rand(3, 3)
1114
mulX = sqrt.(abs.(testva .* X))
@@ -161,15 +164,15 @@ diffeq = DiffEqArray(recs, t)
161164
@test diffeq[:, (end - 1):end].t == t[(length(t) - 1):length(t)]
162165

163166
# Test views of heterogeneous arrays (issue #453)
164-
f = VectorOfArray([[1.0], [2.0, 3.0]])
167+
f = VA[[1.0], [2.0, 3.0]]
165168
@test length(view(f, :, 1)) == 1
166169
@test length(view(f, :, 2)) == 2
167170
@test view(f, :, 1) == [1.0]
168171
@test view(f, :, 2) == [2.0, 3.0]
169172
@test collect(view(f, :, 1)) == f[:, 1]
170173
@test collect(view(f, :, 2)) == f[:, 2]
171174

172-
f2 = VectorOfArray([[1.0, 2.0], [3.0]])
175+
f2 = VA[[1.0, 2.0], [3.0]]
173176
@test length(view(f2, :, 1)) == 2
174177
@test length(view(f2, :, 2)) == 1
175178
@test view(f2, :, 1) == [1.0, 2.0]
@@ -178,7 +181,7 @@ f2 = VectorOfArray([[1.0, 2.0], [3.0]])
178181
@test collect(view(f2, :, 2)) == f2[:, 2]
179182

180183
# Test `end` with ragged arrays
181-
ragged = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]])
184+
ragged = VA[[1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]]
182185
@test ragged[end, 1] == 2.0
183186
@test ragged[end, 2] == 5.0
184187
@test ragged[end, 3] == 9.0
@@ -192,7 +195,7 @@ ragged = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]])
192195
@test ragged[:, 2:end] == VectorOfArray(ragged.u[2:end])
193196
@test ragged[:, (end - 1):end] == VectorOfArray(ragged.u[(end - 1):end])
194197

195-
ragged2 = VectorOfArray([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [7.0, 8.0, 9.0]])
198+
ragged2 = VA[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [7.0, 8.0, 9.0]]
196199
@test ragged2[end, 1] == 4.0
197200
@test ragged2[end, 2] == 6.0
198201
@test ragged2[end, 3] == 9.0
@@ -228,7 +231,7 @@ ragged_range_idx = 1:lastindex(ragged, 1)
228231
@test identity.(ragged_range_idx) === ragged_range_idx
229232

230233
# Broadcasting of heterogeneous arrays (issue #454)
231-
u = VectorOfArray([[1.0], [2.0, 3.0]])
234+
u = VA[[1.0], [2.0, 3.0]]
232235
@test length(view(u, :, 1)) == 1
233236
@test length(view(u, :, 2)) == 2
234237
# broadcast assignment into selected column (last index Int)
@@ -293,7 +296,7 @@ u[1:2, 1, [1, 3], 2] .= [1.0 3.0; 2.0 4.0]
293296
@test u[end, 1, end] == u.u[end][end, 1, end]
294297

295298
# Test that views can be modified
296-
f3 = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0]])
299+
f3 = VA[[1.0, 2.0], [3.0, 4.0, 5.0]]
297300
v = view(f3, :, 2)
298301
@test length(v) == 3
299302
v[1] = 10.0
@@ -384,7 +387,7 @@ mulX .= sqrt.(abs.(testva .* testvb))
384387
@test mulX == ref
385388

386389
# https://github.com/SciML/RecursiveArrayTools.jl/issues/49
387-
a = ArrayPartition(1:5, 1:6)
390+
a = AP[1:5, 1:6]
388391
a[1:8]
389392
a[[1, 3, 8]]
390393

test/downstream/downstream_events.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
using OrdinaryDiffEq, StaticArrays, RecursiveArrayTools
2-
u0 = ArrayPartition(SVector{1}(50.0), SVector{1}(0.0))
2+
u0 = AP[SVector{1}(50.0), SVector{1}(0.0)]
33
tspan = (0.0, 15.0)
44

55
function f(u, p, t)
6-
return ArrayPartition(SVector{1}(u[2]), SVector{1}(-9.81))
6+
return AP[SVector{1}(u[2]), SVector{1}(-9.81)]
77
end
88

99
prob = ODEProblem(f, u0, tspan)
@@ -13,7 +13,7 @@ function condition(u, t, integrator) # Event when event_f(u,t,k) == 0
1313
end
1414

1515
affect! = nothingf = affect_neg! = function (integrator)
16-
return integrator.u = ArrayPartition(SVector{1}(integrator.u[1]), SVector{1}(-integrator.u[2]))
16+
return integrator.u = AP[SVector{1}(integrator.u[1]), SVector{1}(-integrator.u[2])]
1717
end
1818

1919
callback = ContinuousCallback(condition, affect!, affect_neg!, interp_points = 100)

test/downstream/odesolve.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function lorenz(du, u, p, t)
44
du[2] = u[1] * (28.0 - u[3]) - u[2]
55
return du[3] = u[1] * u[2] - (8 / 3) * u[3]
66
end
7-
u0 = ArrayPartition([1.0, 0.0], [0.0])
7+
u0 = AP[[1.0, 0.0], [0.0]]
88
@test ArrayInterface.zeromatrix(u0) isa Matrix
99
tspan = (0.0, 100.0)
1010
prob = ODEProblem(lorenz, u0, tspan)
@@ -25,26 +25,26 @@ function mymodel(F, vars)
2525
return
2626
end
2727
# To show that the function works
28-
F = ArrayPartition([0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0])
29-
u0 = ArrayPartition([0.1 1.2; 0.1 1.2], [0.1 1.2; 0.1 1.2])
28+
F = AP[[0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0]]
29+
u0 = AP[[0.1 1.2; 0.1 1.2], [0.1 1.2; 0.1 1.2]]
3030
result = mymodel(F, u0)
3131
nlsolve(mymodel, u0)
3232

3333
# Nested ArrayPartition solves
3434

3535
function dyn(u, p, t)
3636
return ArrayPartition(
37-
ArrayPartition(zeros(1), [0.0]),
38-
ArrayPartition(zeros(1), [0.0])
37+
AP[zeros(1), [0.0]],
38+
AP[zeros(1), [0.0]]
3939
)
4040
end
4141

4242
@test solve(
4343
ODEProblem(
4444
dyn,
4545
ArrayPartition(
46-
ArrayPartition(zeros(1), [-1.0]),
47-
ArrayPartition(zeros(1), [0.75])
46+
AP[zeros(1), [-1.0]],
47+
AP[zeros(1), [0.75]]
4848
),
4949
(0.0, 1.0)
5050
),
@@ -55,8 +55,8 @@ end
5555
ODEProblem(
5656
dyn,
5757
ArrayPartition(
58-
ArrayPartition(zeros(1), [-1.0]),
59-
ArrayPartition(zeros(1), [0.75])
58+
AP[zeros(1), [-1.0]],
59+
AP[zeros(1), [0.75]]
6060
),
6161
(0.0, 1.0)
6262
),

0 commit comments

Comments
 (0)