|
| 1 | +test_result(a::AbstractArray, b::AbstractArray; kwargs...) = |
| 2 | + isapprox(Array(a), Array(b); kwargs...) |
| 3 | +test_result(a::Number, b::Number; kwargs...) = isapprox(a, b; kwargs...) |
| 4 | + |
| 5 | +function compare(f, AT::Type, xs...; kwargs...) |
| 6 | + cpu_in = map(deepcopy, xs) # copy on CPU |
| 7 | + gpu_in = map(adapt(AT), xs) # adapt on GPU |
| 8 | + |
| 9 | + cpu_out = f(cpu_in...) |
| 10 | + gpu_out = f(gpu_in...) |
| 11 | + return test_result(cpu_out, gpu_out; kwargs...) |
| 12 | +end |
| 13 | + |
| 14 | +# types to test for |
| 15 | +ATs = [] |
| 16 | +!is_buildkite && push!(ATs, JLArray) |
| 17 | +CUDA.functional() && push!(ATs, CuArray) |
| 18 | +AMDGPU.functional() && push!(ATs, ROCArray) |
| 19 | +Metal.functional() && push!(ATs, MtlArray) |
| 20 | + |
| 21 | +@testset "in-place matrix operations ($AT)" for AT in ATs |
| 22 | + for T in (Float32, ComplexF32) |
| 23 | + A1 = StridedView(randn(T, 20, 20)) |
| 24 | + A2 = StridedView(randn(T, 20, 20)) |
| 25 | + |
| 26 | + @test compare(conj!, AT, A1) |
| 27 | + @test compare(adjoint!, AT, A1, A2) |
| 28 | + @test compare(transpose!, AT, A1, A2) |
| 29 | + @test compare((x, y) -> permutedims!(x, y, (2, 1)), AT, A1, A2) |
| 30 | + |
| 31 | + B1 = A1[4:4:end, 1:4:end] |
| 32 | + B2 = A2[4:4:end, 1:4:end] |
| 33 | + |
| 34 | + @test compare(conj!, AT, B1) |
| 35 | + @test compare(adjoint!, AT, B1, B2) |
| 36 | + @test compare(transpose!, AT, B1, B2) |
| 37 | + @test compare((x, y) -> permutedims!(x, y, (2, 1)), AT, B1, B2) |
| 38 | + end |
| 39 | +end |
| 40 | + |
| 41 | +@testset "map, scale!, axpy!, axpby! ($AT)" for AT in ATs |
| 42 | + for T in (Float32, ComplexF32) |
| 43 | + for N in 2:6 |
| 44 | + dims = ntuple(Returns(div(60, N)), N) |
| 45 | + A1 = permutedims(StridedView(rand(T, dims)), randperm(N)) |
| 46 | + A2 = permutedims(StridedView(rand(T, dims)), randperm(N)) |
| 47 | + A3 = permutedims(StridedView(rand(T, dims)), randperm(N)) |
| 48 | + |
| 49 | + @test compare(x -> rmul!(x, 1 // 2), AT, A1) |
| 50 | + @test compare(x -> lmul!(1 // 3, x), AT, A2) |
| 51 | + @test compare((x, y) -> axpy!(1 // 3, x, y), AT, A1, A2) |
| 52 | + @test compare((x, y) -> axpby!(1 // 3, x, 1 // 2, y), AT, A1, A2) |
| 53 | + @test compare((x, y, z) -> map((a, b, c) -> sin(a) + b / exp(-abs(c)), x, y, z), AT, A1, A2, A3) |
| 54 | + @test compare((x, y) -> mul!(x, 1, y), AT, A1, A2) |
| 55 | + @test compare((x, y) -> mul!(x, y, 1), AT, A1, A2) |
| 56 | + end |
| 57 | + |
| 58 | + dims = ntuple(Returns(20), 2) |
| 59 | + A1 = permutedims(StridedView(rand(T, dims))[2:2:end, 2:2:end], randperm(2)) |
| 60 | + A2 = permutedims(StridedView(rand(T, dims))[2:2:end, 2:2:end], randperm(2)) |
| 61 | + A3 = permutedims(StridedView(rand(T, dims))[2:2:end, 2:2:end], randperm(2)) |
| 62 | + @test compare(x -> rmul!(x, 1 // 2), AT, A1) |
| 63 | + @test compare(x -> lmul!(1 // 3, x), AT, A2) |
| 64 | + @test compare((x, y) -> axpy!(1 // 3, x, y), AT, A1, A2) |
| 65 | + @test compare((x, y) -> axpby!(1 // 3, x, 1 // 2, y), AT, A1, A2) |
| 66 | + @test compare((x, y, z) -> map((a, b, c) -> sin(a) + b / exp(-abs(c)), x, y, z), AT, A1, A2, A3) |
| 67 | + @test compare((x, y) -> mul!(x, 1, y), AT, A1, A2) |
| 68 | + @test compare((x, y) -> mul!(x, y, 1), AT, A1, A2) |
| 69 | + end |
| 70 | +end |
| 71 | + |
| 72 | +@testset "broadcasting ($AT)" for AT in ATs |
| 73 | + for T in (Float32, ComplexF32) |
| 74 | + A0 = StridedView(rand(T, ())) |
| 75 | + A1 = StridedView(rand(T, (10,))) |
| 76 | + A2 = permutedims(StridedView(rand(T, (10, 10))), randperm(2)) |
| 77 | + A3 = permutedims(StridedView(rand(T, (10, 10, 10))), randperm(3)) |
| 78 | + A4 = StridedView(rand(T, (2, 0))) |
| 79 | + |
| 80 | + @test compare((x, y) -> x .+ sin.(y .- 3), AT, A1, A2) |
| 81 | + @test compare((y, z) -> y' .* z .- Ref(1 // 2), AT, A2, A3) |
| 82 | + @test compare((x, y, z) -> y' .* z .- max.(abs.(x), real.(z)), AT, A1, A2, A3) |
| 83 | + @test compare((u, y, z) -> y' .* z .- u, AT, A0, A2, A3) |
| 84 | + |
| 85 | + @test compare(x -> x .+ x, AT, A4) |
| 86 | + end |
| 87 | +end |
| 88 | + |
| 89 | +@testset "mapreduce ($AT)" for AT in ATs |
| 90 | + sz = 10 |
| 91 | + N = 6 |
| 92 | + for T in (Float32, ComplexF32) |
| 93 | + A1 = StridedView(rand(T, ntuple(Returns(sz), N))) |
| 94 | + |
| 95 | + @test compare(x -> sum(x; dims = (1, 3, 5)), AT, A1) |
| 96 | + @test compare(x -> mapreduce(sin, +, x; dims = (1, 3, 5)), AT, A1) |
| 97 | + @test compare(x -> sum(x; dims = (1, 3, 5)), AT, permutedims(A1, randperm(N))) |
| 98 | + @test compare(x -> mapreduce(sin, +, x; dims = (1, 3, 5)), AT, permutedims(A1, randperm(N))) |
| 99 | + |
| 100 | + A2 = sreshape(StridedView(rand(T, ntuple(Returns(sz), 3))), (sz, 1, 1, sz, sz, 1)) |
| 101 | + |
| 102 | + @test compare((x, y) -> Strided._mapreducedim!(sin, +, identity, ntuple(Returns(sz), N), (x, y)), AT, A1, A2) |
| 103 | + @test compare((x, y) -> Strided._mapreducedim!(sin, +, Returns(0), ntuple(Returns(sz), N), (x, y)), AT, A1, A2) |
| 104 | + @test compare((x, y) -> Strided._mapreducedim!(sin, +, conj, ntuple(Returns(sz), N), (x, y)), AT, A1, A2) |
| 105 | + |
| 106 | + β = rand(T) |
| 107 | + @test compare((x, y) -> Strided._mapreducedim!(sin, +, a -> β, ntuple(Returns(sz), N), (x, y)), AT, A1, A2) |
| 108 | + @test compare((x, y) -> Strided._mapreducedim!(sin, +, a -> β * a, ntuple(Returns(sz), N), (x, y)), AT, A1, A2) |
| 109 | + end |
| 110 | +end |
| 111 | + |
| 112 | +@testset "reduce ($AT)" for AT in ATs |
| 113 | + N = 4 |
| 114 | + for T in (Float32, ComplexF32) |
| 115 | + A1 = StridedView(rand(T, ntuple(Returns(10), N))) |
| 116 | + A2 = permutedims(StridedView(rand(T, ntuple(Returns(10), N))), randperm(N)) |
| 117 | + @test compare(sum, AT, A1) |
| 118 | + @test compare(sum, AT, A2) |
| 119 | + @test compare(x -> maximum(real, x), AT, A1) |
| 120 | + @test compare(x -> maximum(abs, x), AT, A2) |
| 121 | + @test compare(x -> minimum(abs, x), AT, A1) |
| 122 | + @test compare(x -> minimum(real, x), AT, A2) |
| 123 | + |
| 124 | + A3 = StridedView(rand(T, (5, 5, 5))) |
| 125 | + @test compare(x -> prod(exp, x), AT, A3) |
| 126 | + end |
| 127 | +end |
0 commit comments