Skip to content

Commit 5c3025e

Browse files
committed
Add FillArrays extension
1 parent d81ef3b commit 5c3025e

4 files changed

Lines changed: 205 additions & 1 deletion

File tree

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1010

1111
[weakdeps]
1212
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
13+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1314

1415
[extensions]
1516
StridedViewsCUDAExt = "CUDA"
17+
StridedViewsFillArraysExt = "FillArrays"
1618

1719
[compat]
1820
Aqua = "0.8"
@@ -22,11 +24,13 @@ PackageExtensionCompat = "1"
2224
Random = "1.6"
2325
Test = "1.6"
2426
julia = "1.6"
27+
FillArrays = "1"
2528

2629
[extras]
2730
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
2831
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2932
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
33+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
3034

3135
[targets]
32-
test = ["Test", "Random", "Aqua"]
36+
test = ["Test", "Random", "Aqua", "FillArrays"]

ext/StridedViewsFillArraysExt.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
module StridedViewsFillArraysExt
2+
3+
using StridedViews
4+
using FillArrays
5+
using FillArrays: AbstractFill, getindex_value
6+
7+
_strides(A::AbstractFill) = (1, Base.size_to_strides(size(A)...)...)
8+
_strides(::AbstractFill{T,0}) where {T} = ( )
9+
10+
function StridedViews.StridedView(parent::A, sz::NTuple{N,Int}=size(parent),
11+
st::NTuple{N,Int}=_strides(parent),
12+
offset::Int=0, op::F=identity) where {A<:AbstractFill,N,F}
13+
T = Base.promote_op(op, eltype(parent))
14+
return StridedView{T,N,A,F}(parent, sz, st, offset, op)
15+
end
16+
17+
function FillArrays.getindex_value(a::StridedView{T,N,A}) where {T,N,A<:AbstractFill}
18+
return a.op(getindex_value(parent(a)))
19+
end
20+
21+
# short-circuit indexing to only call checkbounds, no index computation needed
22+
@inline function Base.getindex(a::StridedView{T,N,A}, I::Vararg{Int,N}) where {T,N,A<:AbstractFill}
23+
@boundscheck checkbounds(a, I...)
24+
return getindex_value(a)
25+
end
26+
@inline function Base.setindex!(a::StridedView{T,N,A}, v, I::Vararg{Int,N}) where {T,N,A<:AbstractFill}
27+
@boundscheck checkbounds(a, I...)
28+
v == getindex_value(a) || throw(ArgumentError("Cannot setindex! to $v for an AbstractFill with value $(getindex_value(a))."))
29+
return a
30+
end
31+
32+
end

test/fillarrays.jl

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
module FillArrayTests
2+
3+
using Test
4+
using LinearAlgebra
5+
using StridedViews
6+
using FillArrays
7+
using Random
8+
Random.seed!(1234)
9+
10+
@testset "FillArrays" verbose=true begin
11+
@testset for T1 in (Float32, Float64, Complex{Float32}, Complex{Float64})
12+
A1 = Fill(rand(T1), (60, 60))
13+
B1 = StridedView(A1)
14+
C1 = StridedView(B1)
15+
@test C1 === B1
16+
@test parent(B1) === A1
17+
for op1 in (identity, conj, transpose, adjoint)
18+
if op1 == transpose || op1 == adjoint
19+
@test op1(A1) == op1(B1) == StridedView(op1(A1))
20+
else
21+
@test op1(A1) == op1(B1)
22+
end
23+
for op2 in (identity, conj, transpose, adjoint)
24+
@test op2(op1(A1)) == op2(op1(B1))
25+
end
26+
end
27+
28+
A2 = view(A1, 1:36, 1:20)
29+
B2 = StridedView(A2)
30+
for op1 in (identity, conj, transpose, adjoint)
31+
if op1 == transpose || op1 == adjoint
32+
@test op1(A2) == op1(B2) == StridedView(op1(A2))
33+
else
34+
@test op1(A2) == op1(B2)
35+
end
36+
for op2 in (identity, conj, transpose, adjoint)
37+
@test op2(op1(A2)) == op2(op1(B2))
38+
end
39+
end
40+
41+
A3 = reshape(A1, 360, 10)
42+
B3 = StridedView(A3)
43+
@test size(A3) == size(B3)
44+
for op1 in (identity, conj, transpose, adjoint)
45+
if op1 == transpose || op1 == adjoint
46+
@test op1(A3) == op1(B3) == StridedView(op1(A3))
47+
else
48+
@test op1(A3) == op1(B3)
49+
end
50+
for op2 in (identity, conj, transpose, adjoint)
51+
@test op2(op1(A3)) == op2(op1(B3))
52+
end
53+
end
54+
55+
A4 = reshape(view(A1, 1:36, 1:20), (6, 6, 5, 4))
56+
B4 = StridedView(A4)
57+
for op1 in (identity, conj)
58+
@test op1(A4) == op1(B4)
59+
for op2 in (identity, conj)
60+
@test op2(op1(A4)) == op2(op1(B4))
61+
end
62+
end
63+
64+
A5 = PermutedDimsArray(reshape(view(A1, 1:36, 1:20), (6, 6, 5, 4)), (3, 1, 2, 4))
65+
B5 = StridedView(A5)
66+
for op1 in (identity, conj)
67+
@test op1(A5) == op1(B5)
68+
for op2 in (identity, conj)
69+
@test op2(op1(A5)) == op2(op1(B5))
70+
end
71+
end
72+
73+
# Zero-dimensional array is currently broken, see https://github.com/JuliaArrays/FillArrays.jl/issues/145
74+
# A8 = Fill(rand(T1), ())
75+
# B8 = StridedView(A8)
76+
# @test stride(B8, 1) == stride(B8, 5) == 1
77+
# for op1 in (identity, conj)
78+
# @test op1(A8) == op1(B8) == StridedView(op1(A8))
79+
# for op2 in (identity, conj)
80+
# @test op2(op1(A8)) == op2(op1(B8))
81+
# end
82+
# end
83+
# @test reshape(B8, (1, 1, 1)) == reshape(A8, (1, 1, 1)) ==
84+
# StridedView(reshape(A8, (1, 1, 1))) == sreshape(A8, (1, 1, 1))
85+
# @test reshape(B8, ()) == reshape(A8, ())
86+
end
87+
88+
@testset "transpose and adjoint with vector StridedView" begin
89+
@testset for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
90+
A = Fill(rand(T), (60,))
91+
92+
@test sreshape(transpose(A), (1, length(A))) == transpose(A)
93+
@test sreshape(adjoint(A), (1, length(A))) == adjoint(A)
94+
end
95+
end
96+
97+
@testset "reshape and permutedims with StridedView" begin
98+
@testset for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
99+
@testset "reshape and permutedims with $N-dimensional arrays" for N in 2:6
100+
let dims = ntuple(n -> rand(1:div(60, N)), N)
101+
A = Fill(rand(T), dims)
102+
B = StridedView(A)
103+
@test conj(A) == conj(B)
104+
p = randperm(N)
105+
B2 = permutedims(B, p)
106+
A2 = permutedims(A, p)
107+
@test B2 == A2
108+
end
109+
110+
let dims = ntuple(n -> 10, N)
111+
A = Fill(rand(T), dims)
112+
B = StridedView(A)
113+
@test conj(A) == conj(B)
114+
p = randperm(N)
115+
B2 = permutedims(B, p)
116+
A2 = permutedims(A, p)
117+
@test B2 == A2
118+
119+
B2 = sreshape(B, (2, 5, ntuple(n -> 10, N - 2)..., 5, 2))
120+
A2 = sreshape(A, (2, 5, ntuple(n -> 10, N - 2)..., 5, 2)...)
121+
A3 = reshape(A, size(A2))
122+
@test B2 == A3
123+
@test B2 == A2
124+
p = randperm(N + 2)
125+
@test conj(permutedims(B2, p)) == conj(permutedims(A3, p))
126+
end
127+
end
128+
129+
@testset "more reshape" begin
130+
A = Ones(4, 0)
131+
B = StridedView(A)
132+
@test_throws DimensionMismatch sreshape(B, (4, 1))
133+
C = sreshape(B, (2, 1, 2, 0, 1))
134+
@test sreshape(C, (4, 0)) == A
135+
136+
A = Trues(4, 1, 2)
137+
B = StridedView(A)
138+
@test_throws DimensionMismatch sreshape(B, (4, 4))
139+
C = sreshape(B, (2, 1, 1, 4, 1, 1))
140+
@test C == reshape(A, (2, 1, 1, 4, 1, 1))
141+
@test sreshape(C, (4, 1, 2)) == A
142+
end
143+
end
144+
end
145+
146+
@testset "views with StridedView" begin
147+
@testset for T in (Float32, Float64, ComplexF32, ComplexF64)
148+
A = Fill(rand(T), (10, 10, 10, 10))
149+
B = StridedView(A)
150+
@test isa(view(B, :, 1:5, 3, 1:5), StridedView)
151+
@test isa(view(B, :, [1, 2, 3], 3, 1:5), Base.SubArray)
152+
@test isa(sview(B, :, 1:5, 3, 1:5), StridedView)
153+
@test_throws MethodError sview(B, :, [1, 2, 3], 3, 1:5)
154+
155+
@test view(A, 1:38) == view(B, 1:38) == sview(A, 1:38) == sview(B, 1:38)
156+
157+
@test view(B, :, 1:5, 3, 1:5) == view(A, :, 1:5, 3, 1:5) ==
158+
sview(A, :, 1:5, 3, 1:5)
159+
@test view(B, :, 1:5, 3, 1:5) === sview(B, :, 1:5, 3, 1:5) === B[:, 1:5, 3, 1:5]
160+
@test view(B, :, 1:5, 3, 1:5) == StridedView(view(A, :, 1:5, 3, 1:5))
161+
@test StridedViews.offset(view(B, :, 1:5, 3, 1:5)) == 2 * stride(B, 3)
162+
end
163+
end
164+
end
165+
166+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,5 +214,7 @@ end
214214
end
215215
end
216216

217+
include("fillarrays.jl")
218+
217219
using Aqua
218220
Aqua.test_all(StridedViews)

0 commit comments

Comments
 (0)