Skip to content

Commit 7629000

Browse files
committed
clean up tests
1 parent 1aba36a commit 7629000

6 files changed

Lines changed: 274 additions & 387 deletions

File tree

test/amd.jl

Lines changed: 0 additions & 29 deletions
This file was deleted.

test/cuda.jl

Lines changed: 0 additions & 24 deletions
This file was deleted.

test/gpu.jl

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
test_result(a::AbstractArray, b::AbstractArray; kwargs...) =
2+
isapprox(Array(a), Array(b); kwargs...)
3+
test_result(a::Number, b::Number; kwargs...) = isapprox(a, b; kwargs...)
4+
5+
function compare(f, AT::Type, xs...; kwargs...)
6+
cpu_in = map(deepcopy, xs) # copy on CPU
7+
gpu_in = map(adapt(AT), xs) # adapt on GPU
8+
9+
cpu_out = f(cpu_in...)
10+
gpu_out = f(gpu_in...)
11+
return test_result(cpu_out, gpu_out; kwargs...)
12+
end
13+
14+
# types to test for
15+
ATs = []
16+
!is_buildkite && push!(ATs, JLArray)
17+
CUDA.functional() && push!(ATs, CuArray)
18+
AMDGPU.functional() && push!(ATs, ROCArray)
19+
Metal.functional() && push!(ATs, MtlArray)
20+
21+
@testset "in-place matrix operations ($AT)" for AT in ATs
22+
for T in (Float32, ComplexF32)
23+
A1 = StridedView(randn(T, 20, 20))
24+
A2 = StridedView(randn(T, 20, 20))
25+
26+
@test compare(conj!, AT, A1)
27+
@test compare(adjoint!, AT, A1, A2)
28+
@test compare(transpose!, AT, A1, A2)
29+
@test compare((x, y) -> permutedims!(x, y, (2, 1)), AT, A1, A2)
30+
31+
B1 = A1[4:4:end, 1:4:end]
32+
B2 = A2[4:4:end, 1:4:end]
33+
34+
@test compare(conj!, AT, B1)
35+
@test compare(adjoint!, AT, B1, B2)
36+
@test compare(transpose!, AT, B1, B2)
37+
@test compare((x, y) -> permutedims!(x, y, (2, 1)), AT, B1, B2)
38+
end
39+
end
40+
41+
@testset "map, scale!, axpy!, axpby! ($AT)" for AT in ATs
42+
for T in (Float32, ComplexF32)
43+
for N in 2:6
44+
dims = ntuple(Returns(div(60, N)), N)
45+
A1 = permutedims(StridedView(rand(T, dims)), randperm(N))
46+
A2 = permutedims(StridedView(rand(T, dims)), randperm(N))
47+
A3 = permutedims(StridedView(rand(T, dims)), randperm(N))
48+
49+
@test compare(x -> rmul!(x, 1 // 2), AT, A1)
50+
@test compare(x -> lmul!(1 // 3, x), AT, A2)
51+
@test compare((x, y) -> axpy!(1 // 3, x, y), AT, A1, A2)
52+
@test compare((x, y) -> axpby!(1 // 3, x, 1 // 2, y), AT, A1, A2)
53+
@test compare((x, y, z) -> map((a, b, c) -> sin(a) + b / exp(-abs(c)), x, y, z), AT, A1, A2, A3)
54+
@test compare((x, y) -> mul!(x, 1, y), AT, A1, A2)
55+
@test compare((x, y) -> mul!(x, y, 1), AT, A1, A2)
56+
end
57+
58+
dims = ntuple(Returns(20), 2)
59+
A1 = permutedims(StridedView(rand(T, dims))[2:2:end, 2:2:end], randperm(2))
60+
A2 = permutedims(StridedView(rand(T, dims))[2:2:end, 2:2:end], randperm(2))
61+
A3 = permutedims(StridedView(rand(T, dims))[2:2:end, 2:2:end], randperm(2))
62+
@test compare(x -> rmul!(x, 1 // 2), AT, A1)
63+
@test compare(x -> lmul!(1 // 3, x), AT, A2)
64+
@test compare((x, y) -> axpy!(1 // 3, x, y), AT, A1, A2)
65+
@test compare((x, y) -> axpby!(1 // 3, x, 1 // 2, y), AT, A1, A2)
66+
@test compare((x, y, z) -> map((a, b, c) -> sin(a) + b / exp(-abs(c)), x, y, z), AT, A1, A2, A3)
67+
@test compare((x, y) -> mul!(x, 1, y), AT, A1, A2)
68+
@test compare((x, y) -> mul!(x, y, 1), AT, A1, A2)
69+
end
70+
end
71+
72+
@testset "broadcasting ($AT)" for AT in ATs
73+
for T in (Float32, ComplexF32)
74+
A0 = StridedView(rand(T, ()))
75+
A1 = StridedView(rand(T, (10,)))
76+
A2 = permutedims(StridedView(rand(T, (10, 10))), randperm(2))
77+
A3 = permutedims(StridedView(rand(T, (10, 10, 10))), randperm(3))
78+
A4 = StridedView(rand(T, (2, 0)))
79+
80+
@test compare((x, y) -> x .+ sin.(y .- 3), AT, A1, A2)
81+
@test compare((y, z) -> y' .* z .- Ref(1 // 2), AT, A2, A3)
82+
@test compare((x, y, z) -> y' .* z .- max.(abs.(x), real.(z)), AT, A1, A2, A3)
83+
@test compare((u, y, z) -> y' .* z .- u, AT, A0, A2, A3)
84+
85+
@test compare(x -> x .+ x, AT, A4)
86+
end
87+
end
88+
89+
@testset "mapreduce ($AT)" for AT in ATs
90+
sz = 10
91+
N = 6
92+
for T in (Float32, ComplexF32)
93+
A1 = StridedView(rand(T, ntuple(Returns(sz), N)))
94+
95+
@test compare(x -> sum(x; dims = (1, 3, 5)), AT, A1)
96+
@test compare(x -> mapreduce(sin, +, x; dims = (1, 3, 5)), AT, A1)
97+
@test compare(x -> sum(x; dims = (1, 3, 5)), AT, permutedims(A1, randperm(N)))
98+
@test compare(x -> mapreduce(sin, +, x; dims = (1, 3, 5)), AT, permutedims(A1, randperm(N)))
99+
100+
A2 = sreshape(StridedView(rand(T, ntuple(Returns(sz), 3))), (sz, 1, 1, sz, sz, 1))
101+
102+
@test compare((x, y) -> Strided._mapreducedim!(sin, +, identity, ntuple(Returns(sz), N), (x, y)), AT, A1, A2)
103+
@test compare((x, y) -> Strided._mapreducedim!(sin, +, Returns(0), ntuple(Returns(sz), N), (x, y)), AT, A1, A2)
104+
@test compare((x, y) -> Strided._mapreducedim!(sin, +, conj, ntuple(Returns(sz), N), (x, y)), AT, A1, A2)
105+
106+
β = rand(T)
107+
@test compare((x, y) -> Strided._mapreducedim!(sin, +, a -> β, ntuple(Returns(sz), N), (x, y)), AT, A1, A2)
108+
@test compare((x, y) -> Strided._mapreducedim!(sin, +, a -> β * a, ntuple(Returns(sz), N), (x, y)), AT, A1, A2)
109+
end
110+
end
111+
112+
@testset "reduce ($AT)" for AT in ATs
113+
N = 4
114+
for T in (Float32, ComplexF32)
115+
A1 = StridedView(rand(T, ntuple(Returns(10), N)))
116+
A2 = permutedims(StridedView(rand(T, ntuple(Returns(10), N))), randperm(N))
117+
@test compare(sum, AT, A1)
118+
@test compare(sum, AT, A2)
119+
@test compare(x -> maximum(real, x), AT, A1)
120+
@test compare(x -> maximum(abs, x), AT, A2)
121+
@test compare(x -> minimum(abs, x), AT, A1)
122+
@test compare(x -> minimum(real, x), AT, A2)
123+
124+
A3 = StridedView(rand(T, (5, 5, 5)))
125+
@test compare(x -> prod(exp, x), AT, A3)
126+
end
127+
end

test/mapreduce_tests.jl

Lines changed: 0 additions & 145 deletions
This file was deleted.

0 commit comments

Comments
 (0)