Skip to content

Commit e1ea618

Browse files
kshyatt and lkdvos authored
Test svd_trunc for GPU (#146)
* Test svd_trunc for GPU
* scalar indexing in tests
* bypass intersect on the GPU
* more scalar indexing in tests
* Get rid of GPUArrays
* Try to unbreak AMDGPU
* Actually fix
* Format
* Try generating AMD unitary special-case
* AMDGPU hates complex rand
---------
Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 4f5bcb1 commit e1ea618

6 files changed

Lines changed: 41 additions & 22 deletions

File tree

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,5 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
5656
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5757

5858
[targets]
59-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]
59+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore",
60+
"ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,9 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix
167167
return C
168168
end
169169

170+
# TODO: intersect on GPU arrays is not working
171+
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::AbstractVector) = MatrixAlgebraKit._ind_intersect(collect(A), B)
172+
MatrixAlgebraKit._ind_intersect(A::AbstractVector, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(A, collect(B))
173+
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
174+
170175
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,4 +191,9 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T
191191
return C
192192
end
193193

194+
# TODO: intersect on GPU arrays is not working
195+
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::AbstractVector) = MatrixAlgebraKit._ind_intersect(collect(A), B)
196+
MatrixAlgebraKit._ind_intersect(A::AbstractVector, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(A, collect(B))
197+
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
198+
194199
end

test/svd.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using TestExtras
44
using StableRNGs
55
using LinearAlgebra: Diagonal
66
using CUDA, AMDGPU
7+
using CUDA.CUSOLVER # pull in opnorm binding
78

89
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
910
GenericFloats = (BigFloat, Complex{BigFloat})
@@ -17,28 +18,28 @@ for T in (BLASFloats..., GenericFloats...), m in (0, 54), n in (0, 37, m, 63)
1718
TestSuite.seed_rng!(123)
1819
if T ∈ BLASFloats
1920
if CUDA.functional()
20-
TestSuite.test_svd(CuMatrix{T}, (m, n); test_trunc = false)
21+
TestSuite.test_svd(CuMatrix{T}, (m, n))
2122
CUDA_SVD_ALGS = (
2223
CUSOLVER_QRIteration(),
2324
CUSOLVER_SVDPolar(),
2425
CUSOLVER_Jacobi(),
2526
)
26-
TestSuite.test_svd_algs(CuMatrix{T}, (m, n), CUDA_SVD_ALGS; test_trunc = false)
27+
TestSuite.test_svd_algs(CuMatrix{T}, (m, n), CUDA_SVD_ALGS)
2728
if n == m
28-
TestSuite.test_svd(Diagonal{T, CuVector{T}}, m; test_trunc = false)
29-
TestSuite.test_svd_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false)
29+
TestSuite.test_svd(Diagonal{T, CuVector{T}}, m)
30+
TestSuite.test_svd_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),))
3031
end
3132
end
3233
if AMDGPU.functional()
33-
TestSuite.test_svd(ROCMatrix{T}, (m, n); test_trunc = false)
34+
TestSuite.test_svd(ROCMatrix{T}, (m, n))
3435
AMD_SVD_ALGS = (
3536
ROCSOLVER_QRIteration(),
3637
ROCSOLVER_Jacobi(),
3738
)
38-
TestSuite.test_svd_algs(ROCMatrix{T}, (m, n), AMD_SVD_ALGS; test_trunc = false)
39+
TestSuite.test_svd_algs(ROCMatrix{T}, (m, n), AMD_SVD_ALGS)
3940
if n == m
40-
TestSuite.test_svd(Diagonal{T, ROCVector{T}}, m; test_trunc = false)
41-
TestSuite.test_svd_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false)
41+
TestSuite.test_svd(Diagonal{T, ROCVector{T}}, m)
42+
TestSuite.test_svd_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),))
4243
end
4344
end
4445
end

test/testsuite/TestSuite.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ isrightcomplete(V::AnyCuMatrix, N::AnyCuMatrix) = isrightcomplete(collect(V), co
7777
isrightcomplete(V::AnyROCMatrix, N::AnyROCMatrix) = isrightcomplete(collect(V), collect(N))
7878

7979
instantiate_unitary(T, A, sz) = qr_compact(randn!(similar(A, eltype(T), sz, sz)))[1]
80+
# AMDGPU can't generate ComplexF32 random numbers
81+
function instantiate_unitary(T, A::ROCMatrix{<:Complex}, sz)
82+
sqA = randn!(similar(A, real(eltype(T)), sz, sz)) .+ im .* randn!(similar(A, real(eltype(T)), sz, sz))
83+
return qr_compact(sqA)[1]
84+
end
8085
instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A), eltype(A), sz), one(eltype(A))))
8186

8287
include("qr.jl")

test/testsuite/svd.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@ using TestExtras
22
using GenericLinearAlgebra
33
using LinearAlgebra: opnorm
44

5-
function test_svd(T::Type, sz; test_trunc = true, kwargs...)
5+
function test_svd(T::Type, sz; kwargs...)
66
summary_str = testargs_summary(T, sz)
77
return @testset "svd $summary_str" begin
88
test_svd_compact(T, sz; kwargs...)
99
test_svd_full(T, sz; kwargs...)
10-
test_trunc && test_svd_trunc(T, sz; kwargs...)
10+
test_svd_trunc(T, sz; kwargs...)
1111
end
1212
end
1313

14-
function test_svd_algs(T::Type, sz, algs; test_trunc = true, kwargs...)
14+
function test_svd_algs(T::Type, sz, algs; kwargs...)
1515
summary_str = testargs_summary(T, sz)
1616
return @testset "svd algorithms $summary_str" begin
1717
test_svd_compact_algs(T, sz, algs; kwargs...)
1818
test_svd_full_algs(T, sz, algs; kwargs...)
19-
test_trunc && test_svd_trunc_algs(T, sz, algs; kwargs...)
19+
test_svd_trunc_algs(T, sz, algs; kwargs...)
2020
end
2121
end
2222

@@ -160,14 +160,15 @@ function test_svd_trunc(
160160
Ac = deepcopy(A)
161161
m, n = size(A)
162162
minmn = min(m, n)
163-
S₀ = svd_vals(A)
163+
S₀ = collect(svd_vals(A))
164164
r = minmn - 2
165165

166166
if m > 0 && n > 0
167167
U1, S1, V1ᴴ, ϵ1 = @testinferred svd_trunc(A; trunc = truncrank(r))
168168
@test length(diagview(S1)) == r
169-
@test diagview(S1) ≈ S₀[1:r]
170-
@test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1]
169+
@test collect(diagview(S1)) ≈ S₀[1:r]
170+
AUSV_vals = svd_vals(A - U1 * S1 * V1ᴴ) # bypass broken svdvals on AMDGPU
171+
@test mapreduce(sv -> opnorm(sv, 2), max, AUSV_vals) ≈ S₀[r + 1]
171172
# Test truncation error
172173
@test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol
173174

@@ -241,14 +242,15 @@ function test_svd_trunc_algs(
241242
Ac = deepcopy(A)
242243
m, n = size(A)
243244
minmn = min(m, n)
244-
S₀ = svd_vals(A)
245+
S₀ = collect(svd_vals(A))
245246
r = minmn - 2
246247

247248
if m > 0 && n > 0
248249
U1, S1, V1ᴴ, ϵ1 = @testinferred svd_trunc(A; trunc = truncrank(r), alg)
249250
@test length(diagview(S1)) == r
250-
@test diagview(S1) ≈ S₀[1:r]
251-
@test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1]
251+
@test collect(diagview(S1)) ≈ S₀[1:r]
252+
AUSV_vals = svd_vals(A - U1 * S1 * V1ᴴ) # bypass broken svdvals on AMDGPU
253+
@test mapreduce(sv -> opnorm(sv, 2), max, AUSV_vals) ≈ S₀[r + 1]
252254
# Test truncation error
253255
@test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol
254256

@@ -285,11 +287,11 @@ function test_svd_trunc_algs(
285287
)
286288
U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; trunc = trunc_fun(0.2, 1), alg)
287289
@test length(diagview(S1)) == 1
288-
@test diagview(S1) ≈ diagview(S)[1:1]
290+
@test collect(diagview(S1)) ≈ collect(diagview(S)[1:1])
289291

290292
U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; trunc = trunc_fun(0.2, 3), alg)
291293
@test length(diagview(S2)) == 2
292-
@test diagview(S2) ≈ diagview(S)[1:2]
294+
@test collect(diagview(S2)) ≈ collect(diagview(S)[1:2])
293295
end
294296
end
295297
@testset "specify truncation algorithm" begin
@@ -303,7 +305,7 @@ function test_svd_trunc_algs(
303305
A = U * S * Vᴴ
304306
truncalg = TruncatedAlgorithm(alg, trunctol(; atol = 0.2))
305307
U2, S2, V2ᴴ, ϵ2 = @testinferred svd_trunc(A; alg = truncalg)
306-
@test diagview(S2) ≈ diagview(S)[1:2]
308+
@test collect(diagview(S2)) ≈ collect(diagview(S)[1:2])
307309
@test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol
308310
@test_throws ArgumentError svd_trunc(A; alg = truncalg, trunc = (; maxrank = 2))
309311
@test_throws ArgumentError svd_trunc_no_error(A; alg = truncalg, trunc = (; maxrank = 2))

0 commit comments

Comments
 (0)