Skip to content

Commit c17ec1c

Browse files
committed
Add more tests
1 parent 4fb1602 commit c17ec1c

3 files changed

Lines changed: 49 additions & 2 deletions

File tree

test/runtests.jl

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,47 @@ using BlockBandedMatrices
22
using ParallelTestRunner
33

44
const init_code = quote
5-
using Test
6-
using BlockBandedMatrices
5+
using Test
6+
using BlockBandedMatrices
7+
8+
function strided_ptr(f, a::AbstractArray{T}) where {T}
9+
a_cconv = Base.cconvert(Ptr{T}, a)
10+
GC.@preserve a_cconv begin
11+
f(Base.unsafe_convert(Ptr{T}, a_cconv))
12+
end
13+
end
14+
15+
"""
16+
check_strided_get(a::AbstractArray{T,N})
17+
18+
Test that array `a` implements the strided array interface for reading.
19+
Checks stride consistency and that `unsafe_load` matches `getindex`.
20+
"""
21+
function check_strided_get(a::AbstractArray{T,N})::Nothing where {T, N}
22+
if !isbitstype(eltype(a))
23+
error("a doesn't have isbits elements")
24+
end
25+
# Putting strided_ptr before the loop means that strided_ptr shouldn't error for empty arrays
26+
strided_ptr(a) do a_ptr
27+
for d in 1:N
28+
if stride(a, d) != strides(a)[d]
29+
error("stride(a, d) doesn't equal strides(a)[d] for dimension $(d)")
30+
end
31+
end
32+
for i in CartesianIndices(a)
33+
el_ptr = a_ptr
34+
for d in 1:N
35+
stride_in_bytes = stride(a, d) * Base.elsize(typeof(a))
36+
first_idx = first(axes(a, d))
37+
el_ptr += (i[d] - first_idx) * stride_in_bytes
38+
end
39+
if unsafe_load(el_ptr) !== a[i]
40+
error("getindex and unsafe_load mismatch at index $(i)")
41+
end
42+
end
43+
end
44+
nothing
45+
end
746
end
847

948
# Start with autodiscovered tests

test/test_linalg.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Test
1010

1111
import BandedMatrices: BandError, bandeddata
1212
import BlockBandedMatrices: _BandedBlockBandedMatrix
13+
import ..check_strided_get
1314

1415
@testset "lmul!/rmul!" begin
1516
C = BandedBlockBandedMatrix{Float64}(undef, 1:2,1:2, (1,1), (1,1))
@@ -45,6 +46,7 @@ end
4546
@test stride(V,2) == 7
4647
@test unsafe_load(pointer(V)) == 46
4748
@test unsafe_load(pointer(V) + stride(V,2)*sizeof(Float64)) == 53
49+
check_strided_get(V)
4850

4951
x = randn(size(A,2))
5052
@test A*x == (similar(x) .= MulAdd(A,x)) Matrix(A)*x
@@ -87,6 +89,7 @@ end
8789

8890
V = view(A, Block(2), Block(2))
8991
@test unsafe_load(Base.unsafe_convert(Ptr{Float64}, bandeddata(V))) == 13.0
92+
check_strided_get(bandeddata(V))
9093

9194
C = BandedMatrix{Float64}(undef, size(V), 2 .*bandwidths(V))
9295
C .= MulAdd(V,V)

test/test_triblockbanded.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import BlockBandedMatrices: MemoryLayout, TriangularLayout,
1212
blockrowstop, blockcolstop, ColumnMajor
1313

1414
import BlockArrays: blockisequal
15+
import ..check_strided_get
1516

1617
@testset "triangular" begin
1718
@testset "triangular BandedBlockBandedMatrix mul" begin
@@ -156,6 +157,7 @@ import BlockArrays: blockisequal
156157
@test unsafe_load(pointer(V)) == A[2,4]
157158
@test unsafe_load(pointer(V)+sizeof(Float64)*stride(V,2)) == A[2,5]
158159
@test MemoryLayout(typeof(V)) == ColumnMajor()
160+
check_strided_get(V)
159161

160162
@test size(V) == (5,3)
161163
b = randn(size(V,2))
@@ -174,6 +176,7 @@ import BlockArrays: blockisequal
174176
@test unsafe_load(pointer(V)) == A[2,5]
175177
@test unsafe_load(pointer(V)+sizeof(Float64)*stride(V,2)) == A[2,6]
176178
@test MemoryLayout(typeof(V)) == ColumnMajor()
179+
check_strided_get(V)
177180

178181
@test size(V) == (5,2)
179182
b = randn(size(V,2))
@@ -194,6 +197,8 @@ import BlockArrays: blockisequal
194197
V_22 = view(A, Block(N)[1:N], Block(N)[1:N])
195198
@test unsafe_load(pointer(V_22)) == V_22[1,1] == V[1,1]
196199
@test strides(V_22) == strides(V) == (1,9)
200+
check_strided_get(V_22)
201+
check_strided_get(V)
197202
b = randn(N)
198203
@test copyto!(similar(b) , MulAdd(V,b)) == copyto!(similar(b) , MulAdd(V_22,b)) ==
199204
Matrix(V)*b ==

0 commit comments

Comments
 (0)