Skip to content

Commit 17ea798

Browse files
authored
Support new projections on GPU (#81)
1 parent 1355804 commit 17ea798

5 files changed

Lines changed: 367 additions & 0 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,106 @@ function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::
5252
return MatrixAlgebraKit.findtruncated(values, strategy)
5353
end
5454

55+
# COV_EXCL_START
56+
# Device kernel: anti-Hermitian combination of two mirrored off-diagonal blocks.
# Each work-item owns one column `col` of `Au`. For every row it writes
# `(Au[row, col] - Al[col, row]') / 2` into `Bu[row, col]` and the negated adjoint of that
# value into the mirrored position `Bl[col, row]`.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{true})
    nrows, ncols = size(Au)
    col = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    # Work-items past the last column have nothing to do.
    if col <= ncols
        for row in 1:nrows
            @inbounds begin
                half_diff = (Au[row, col] - adjoint(Al[col, row])) / 2
                Bu[row, col] = half_diff
                Bl[col, row] = -adjoint(half_diff)
            end
        end
    end
    return
end
69+
70+
# Device kernel: Hermitian combination of two mirrored off-diagonal blocks.
# Each work-item owns one column `col` of `Au`. For every row it writes
# `(Au[row, col] + Al[col, row]') / 2` into `Bu[row, col]` and the adjoint of that value
# into the mirrored position `Bl[col, row]`.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{false})
    nrows, ncols = size(Au)
    col = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    # Work-items past the last column have nothing to do.
    if col <= ncols
        for row in 1:nrows
            @inbounds begin
                half_sum = (Au[row, col] + adjoint(Al[col, row])) / 2
                Bu[row, col] = half_sum
                Bl[col, row] = adjoint(half_sum)
            end
        end
    end
    return
end
83+
84+
# Device kernel: anti-Hermitian projection of a square block `A` into `B`.
# Each work-item owns one column `col`. It walks the strictly upper part of that column,
# writing `(A[row, col] - A[col, row]') / 2` to `B[row, col]` and its negated adjoint to
# `B[col, row]`, then sets the diagonal entry via `MatrixAlgebraKit._imimag`.
function _project_hermitian_diag_kernel(A, B, ::Val{true})
    dim = size(A, 1)
    col = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    # Work-items past the last column have nothing to do.
    if col <= dim
        @inbounds begin
            for row in 1:(col - 1)
                half_diff = (A[row, col] - adjoint(A[col, row])) / 2
                B[row, col] = half_diff
                B[col, row] = -adjoint(half_diff)
            end
            B[col, col] = MatrixAlgebraKit._imimag(A[col, col])
        end
    end
    return
end
98+
99+
# Device kernel: Hermitian projection of a square block `A` into `B`.
# Each work-item owns one column `col`. It walks the strictly upper part of that column,
# writing `(A[row, col] + A[col, row]') / 2` to `B[row, col]` and its adjoint to
# `B[col, row]`, then stores the real part of the diagonal entry.
function _project_hermitian_diag_kernel(A, B, ::Val{false})
    dim = size(A, 1)
    col = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    # Work-items past the last column have nothing to do.
    if col <= dim
        @inbounds begin
            for row in 1:(col - 1)
                half_sum = (A[row, col] + adjoint(A[col, row])) / 2
                B[row, col] = half_sum
                B[col, row] = adjoint(half_sum)
            end
            B[col, col] = real(A[col, col])
        end
    end
    return
end
113+
# COV_EXCL_STOP
114+
115+
# ROCm method: launches `_project_hermitian_offdiag_kernel` with one work-item per column of
# `Au`, in workgroups of 512. `Val(anti)` selects the anti-Hermitian (`true`) or Hermitian
# (`false`) kernel method. Writes into `Bu` and `Bl`; returns `nothing`.
function MatrixAlgebraKit._project_hermitian_offdiag!(
        Au::StridedROCMatrix, Al::StridedROCMatrix, Bu::StridedROCMatrix, Bl::StridedROCMatrix, ::Val{anti}
    ) where {anti}
    # Nothing to write for an empty block; this also avoids requesting a launch with zero
    # workgroups (`cld(0, thread_dim) == 0`).
    isempty(Au) && return nothing
    thread_dim = 512
    block_dim = cld(size(Au, 2), thread_dim)
    @roc groupsize = thread_dim gridsize = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
    return nothing
end
123+
# ROCm method: launches `_project_hermitian_diag_kernel` with one work-item per column of
# the square block `A`, in workgroups of 512. `Val(anti)` selects the anti-Hermitian (`true`)
# or Hermitian (`false`) kernel method. Writes into `B`; returns `nothing`.
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::StridedROCMatrix, ::Val{anti}) where {anti}
    # Nothing to write for an empty block; this also avoids requesting a launch with zero
    # workgroups (`cld(0, thread_dim) == 0`).
    size(A, 1) == 0 && return nothing
    thread_dim = 512
    block_dim = cld(size(A, 1), thread_dim)
    @roc groupsize = thread_dim gridsize = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
    return nothing
end
129+
130+
# Exact Hermiticity check for dense ROCm matrices: every entry must equal the adjoint of its
# mirror entry.
MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = all(A .== adjoint(A))
# A diagonal matrix is Hermitian iff each diagonal entry equals its own adjoint. The adjoint
# is applied elementwise: broadcasting the vector against `adjoint(A.diag)` (a row vector)
# would instead compare every pair of diagonal entries.
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== adjoint.(A.diag))

# Exact anti-Hermiticity check for dense ROCm matrices.
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = all(A .== -adjoint(A))
# Diagonal case: each diagonal entry must equal minus its own adjoint (elementwise, see above).
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== .-adjoint.(A.diag))
135+
136+
# ROCm method: in place, overwrite `A` with `(A + B) / 2` and `B` with `B - A`, where both
# right-hand sides use the values held before the call. Throws `DimensionMismatch` when the
# axes of `A` and `B` differ. Returns `(A, B)`.
function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
    axes(A) == axes(B) || throw(DimensionMismatch())
    # Nothing to update for empty inputs; this also avoids requesting a launch with zero
    # workgroups (`cld(0, thread_dim) == 0`).
    if isempty(A)
        return A, B
    end
    # COV_EXCL_START
    # One work-item per linear index; both inputs are read before either is written.
    function _avgdiff_kernel(A, B)
        j = workitemIdx().x + (workgroupIdx().x - 1) * workgroupDim().x
        j > length(A) && return
        @inbounds begin
            a = A[j]
            b = B[j]
            A[j] = (a + b) / 2
            B[j] = b - a
        end
        return
    end
    # COV_EXCL_STOP
    thread_dim = 512
    block_dim = cld(length(A), thread_dim)
    @roc groupsize = thread_dim gridsize = block_dim _avgdiff_kernel(A, B)
    return A, B
end
156+
55157
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
1111
using CUDA
12+
using CUDA: i32
1213
using LinearAlgebra
1314
using LinearAlgebra: BlasFloat
1415

@@ -58,4 +59,106 @@ function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::T
5859
return MatrixAlgebraKit.findtruncated(values, strategy)
5960
end
6061

62+
# COV_EXCL_START
63+
# Device kernel: anti-Hermitian combination of two mirrored off-diagonal blocks.
# Each thread owns one column `col` of `Au`. For every row it writes
# `(Au[row, col] - Al[col, row]') / 2` into `Bu[row, col]` and the negated adjoint of that
# value into the mirrored position `Bl[col, row]`.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{true})
    nrows, ncols = size(Au)
    col = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    # Threads past the last column have nothing to do.
    if col <= ncols
        for row in 1:nrows
            @inbounds begin
                half_diff = (Au[row, col] - adjoint(Al[col, row])) / 2
                Bu[row, col] = half_diff
                Bl[col, row] = -adjoint(half_diff)
            end
        end
    end
    return
end
76+
77+
# Device kernel: Hermitian combination of two mirrored off-diagonal blocks.
# Each thread owns one column `j` of `Au`. For every row `i` it writes
# `(Au[i, j] + Al[j, i]') / 2` into `Bu[i, j]` and the adjoint of that value into the
# mirrored position `Bl[j, i]`.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{false})
    m, n = size(Au)
    # `1i32` (not `1`) matches the sibling kernels and avoids promoting the thread index
    # computation to `Int64`.
    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
    # Threads past the last column have nothing to do.
    j > n && return
    for i in 1:m
        @inbounds begin
            val = (Au[i, j] + adjoint(Al[j, i])) / 2
            Bu[i, j] = val
            Bl[j, i] = adjoint(val)
        end
    end
    return
end
90+
91+
# Device kernel: anti-Hermitian projection of a square block `A` into `B`.
# Each thread owns one column `j`. It walks the strictly upper part of that column, writing
# `(A[i, j] - A[j, i]') / 2` to `B[i, j]` and its negated adjoint to `B[j, i]`, then sets
# the diagonal entry via `MatrixAlgebraKit._imimag`.
function _project_hermitian_diag_kernel(A, B, ::Val{true})
    n = size(A, 1)
    # `1i32` (not `1`) matches the sibling kernels and avoids promoting the thread index
    # (and with it the loop range below) to `Int64`.
    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
    # Threads past the last column have nothing to do.
    j > n && return
    @inbounds begin
        for i in 1i32:(j - 1i32)
            val = (A[i, j] - adjoint(A[j, i])) / 2
            B[i, j] = val
            B[j, i] = -adjoint(val)
        end
        B[j, j] = MatrixAlgebraKit._imimag(A[j, j])
    end
    return
end
105+
106+
# Device kernel: Hermitian projection of a square block `A` into `B`.
# Each thread owns one column `col`. It walks the strictly upper part of that column,
# writing `(A[row, col] + A[col, row]') / 2` to `B[row, col]` and its adjoint to
# `B[col, row]`, then stores the real part of the diagonal entry.
function _project_hermitian_diag_kernel(A, B, ::Val{false})
    dim = size(A, 1)
    col = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    # Threads past the last column have nothing to do.
    if col <= dim
        @inbounds begin
            for row in 1i32:(col - 1i32)
                half_sum = (A[row, col] + adjoint(A[col, row])) / 2
                B[row, col] = half_sum
                B[col, row] = adjoint(half_sum)
            end
            B[col, col] = real(A[col, col])
        end
    end
    return
end
120+
# COV_EXCL_STOP
121+
122+
# CUDA method: launches `_project_hermitian_offdiag_kernel` with one thread per column of
# `Au`, in blocks of 512 threads. `Val(anti)` selects the anti-Hermitian (`true`) or
# Hermitian (`false`) kernel method. Writes into `Bu` and `Bl`; returns `nothing`.
function MatrixAlgebraKit._project_hermitian_offdiag!(
        Au::StridedCuMatrix, Al::StridedCuMatrix, Bu::StridedCuMatrix, Bl::StridedCuMatrix, ::Val{anti}
    ) where {anti}
    # Nothing to write for an empty block; this also avoids requesting a launch with
    # `blocks = 0` (`cld(0, thread_dim) == 0`), which is not a valid grid size.
    isempty(Au) && return nothing
    thread_dim = 512
    block_dim = cld(size(Au, 2), thread_dim)
    @cuda threads = thread_dim blocks = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
    return nothing
end
130+
# CUDA method: launches `_project_hermitian_diag_kernel` with one thread per column of the
# square block `A`, in blocks of 512 threads. `Val(anti)` selects the anti-Hermitian (`true`)
# or Hermitian (`false`) kernel method. Writes into `B`; returns `nothing`.
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::StridedCuMatrix, ::Val{anti}) where {anti}
    # Nothing to write for an empty block; this also avoids requesting a launch with
    # `blocks = 0` (`cld(0, thread_dim) == 0`), which is not a valid grid size.
    size(A, 1) == 0 && return nothing
    thread_dim = 512
    block_dim = cld(size(A, 1), thread_dim)
    @cuda threads = thread_dim blocks = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
    return nothing
end
136+
137+
# Exact Hermiticity check for dense CUDA matrices: every entry must equal the adjoint of its
# mirror entry.
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = all(A .== adjoint(A))
# A diagonal matrix is Hermitian iff each diagonal entry equals its own adjoint. The adjoint
# is applied elementwise: broadcasting the vector against `adjoint(A.diag)` (a row vector)
# would instead compare every pair of diagonal entries.
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== adjoint.(A.diag))

# Exact anti-Hermiticity check for dense CUDA matrices.
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = all(A .== -adjoint(A))
# Diagonal case: each diagonal entry must equal minus its own adjoint (elementwise, see above).
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== .-adjoint.(A.diag))
142+
143+
# CUDA method: in place, overwrite `A` with `(A + B) / 2` and `B` with `B - A`, where both
# right-hand sides use the values held before the call. Throws `DimensionMismatch` when the
# axes of `A` and `B` differ. Returns `(A, B)`.
function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
    axes(A) == axes(B) || throw(DimensionMismatch())
    # Nothing to update for empty inputs; this also avoids requesting a launch with
    # `blocks = 0` (`cld(0, thread_dim) == 0`), which is not a valid grid size.
    if isempty(A)
        return A, B
    end
    # COV_EXCL_START
    # One thread per linear index; both inputs are read before either is written.
    function _avgdiff_kernel(A, B)
        j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
        j > length(A) && return
        @inbounds begin
            a = A[j]
            b = B[j]
            A[j] = (a + b) / 2
            B[j] = b - a
        end
        return
    end
    # COV_EXCL_STOP
    thread_dim = 512
    block_dim = cld(length(A), thread_dim)
    @cuda threads = thread_dim blocks = block_dim _avgdiff_kernel(A, B)
    return A, B
end
163+
61164
end

test/amd/projections.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, Diagonal, norm
6+
using AMDGPU
7+
8+
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9+
10+
# Checks the out-of-place and in-place (anti-)Hermitian projections on ROCm arrays for
# several block sizes, against the reference values (A ± A') / 2.
@testset "project_(anti)hermitian! for T = $T" for T in BLASFloats
    rng = StableRNG(123)
    m = 54
    noisefactor = eps(real(T))^(3 / 4)
    for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
        A = ROCArray(randn(rng, T, m, m))
        Ah = (A + A') / 2
        Aa = (A - A') / 2
        Ac = copy(A)

        Bh = project_hermitian(A, alg)
        @test ishermitian(Bh)
        @test Bh ≈ Ah
        @test A == Ac # out-of-place call must not modify its input
        # A small anti-Hermitian perturbation breaks exact, but not approximate, Hermiticity.
        Bh_approx = Bh + noisefactor * Aa
        @test !ishermitian(Bh_approx)
        @test ishermitian(Bh_approx; rtol = 10 * noisefactor)

        Ba = project_antihermitian(A, alg)
        @test isantihermitian(Ba)
        @test Ba ≈ Aa
        @test A == Ac # out-of-place call must not modify its input
        # A small Hermitian perturbation breaks exact, but not approximate, anti-Hermiticity.
        Ba_approx = Ba + noisefactor * Ah
        @test !isantihermitian(Ba_approx)
        @test isantihermitian(Ba_approx; rtol = 10 * noisefactor)

        Bh = project_hermitian!(Ac, alg)
        @test Bh === Ac
        @test ishermitian(Bh)
        @test Bh ≈ Ah

        copy!(Ac, A)
        Ba = project_antihermitian!(Ac, alg)
        @test Ba === Ac
        @test isantihermitian(Ba)
        @test Ba ≈ Aa
    end
end
48+
49+
# Checks `project_isometric` / `project_isometric!` on ROCm arrays for tall and square
# inputs, using the SVD-based polar algorithms.
@testset "project_isometric! for T = $T" for T in BLASFloats
    rng = StableRNG(123)
    m = 54
    @testset "size ($m, $n)" for n in (37, m)
        svdalgs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
        algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO
        @testset "algorithm $alg" for alg in algs
            A = ROCArray(randn(rng, T, m, n))
            W = project_isometric(A, alg)
            @test isisometric(W)
            W2 = project_isometric(W, alg)
            @test W2 ≈ W # stability of the projection
            @test W * (W' * A) ≈ A

            Ac = similar(A)
            W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
            @test W2 === W
            @test isisometric(W)

            # test that W is closer to A than any other isometry
            for _ in 1:10
                δA = ROCArray(randn(rng, T, m, n))
                W = project_isometric(A, alg)
                W2 = project_isometric(A + δA / 100, alg)
                @test norm(A - W2) > norm(A - W)
            end
        end
    end
end

test/cuda/projections.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, Diagonal, norm
6+
using CUDA
7+
8+
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9+
10+
# Checks the out-of-place and in-place (anti-)Hermitian projections on CUDA arrays for
# several block sizes, against the reference values (A ± A') / 2.
@testset "project_(anti)hermitian! for T = $T" for T in BLASFloats
    rng = StableRNG(123)
    m = 54
    noisefactor = eps(real(T))^(3 / 4)
    for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
        A = CuArray(randn(rng, T, m, m))
        Ah = (A + A') / 2
        Aa = (A - A') / 2
        Ac = copy(A)

        Bh = project_hermitian(A, alg)
        @test ishermitian(Bh)
        @test Bh ≈ Ah
        @test A == Ac # out-of-place call must not modify its input
        # A small anti-Hermitian perturbation breaks exact, but not approximate, Hermiticity.
        Bh_approx = Bh + noisefactor * Aa
        @test !ishermitian(Bh_approx)
        @test ishermitian(Bh_approx; rtol = 10 * noisefactor)

        Ba = project_antihermitian(A, alg)
        @test isantihermitian(Ba)
        @test Ba ≈ Aa
        @test A == Ac # out-of-place call must not modify its input
        # A small Hermitian perturbation breaks exact, but not approximate, anti-Hermiticity.
        Ba_approx = Ba + noisefactor * Ah
        @test !isantihermitian(Ba_approx)
        @test isantihermitian(Ba_approx; rtol = 10 * noisefactor)

        Bh = project_hermitian!(Ac, alg)
        @test Bh === Ac
        @test ishermitian(Bh)
        @test Bh ≈ Ah

        copy!(Ac, A)
        Ba = project_antihermitian!(Ac, alg)
        @test Ba === Ac
        @test isantihermitian(Ba)
        @test Ba ≈ Aa
    end
end
48+
49+
# Checks `project_isometric` / `project_isometric!` on CUDA arrays for tall and square
# inputs, using the SVD-based polar algorithms.
@testset "project_isometric! for T = $T" for T in BLASFloats
    rng = StableRNG(123)
    m = 54
    @testset "size ($m, $n)" for n in (37, m)
        svdalgs = (CUSOLVER_SVDPolar(), CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
        algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO
        @testset "algorithm $alg" for alg in algs
            A = CuArray(randn(rng, T, m, n))
            W = project_isometric(A, alg)
            @test isisometric(W)
            W2 = project_isometric(W, alg)
            @test W2 ≈ W # stability of the projection
            @test W * (W' * A) ≈ A

            Ac = similar(A)
            W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
            @test W2 === W
            @test isisometric(W)

            # test that W is closer to A than any other isometry
            for _ in 1:10
                δA = CuArray(randn(rng, T, m, n))
                W = project_isometric(A, alg)
                W2 = project_isometric(A + δA / 100, alg)
                @test norm(A - W2) > norm(A - W)
            end
        end
    end
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ if CUDA.functional()
6363
@safetestset "CUDA LQ" begin
6464
include("cuda/lq.jl")
6565
end
66+
@safetestset "CUDA Projections" begin
67+
include("cuda/projections.jl")
68+
end
6669
@safetestset "CUDA SVD" begin
6770
include("cuda/svd.jl")
6871
end
@@ -82,6 +85,9 @@ if AMDGPU.functional()
8285
@safetestset "AMDGPU LQ" begin
8386
include("amd/lq.jl")
8487
end
88+
@safetestset "AMDGPU Projections" begin
89+
include("amd/projections.jl")
90+
end
8591
@safetestset "AMDGPU SVD" begin
8692
include("amd/svd.jl")
8793
end

0 commit comments

Comments
 (0)