Skip to content

Commit 6b9336b

Browse files
committed
Support new projections on GPU
1 parent ba9867b commit 6b9336b

5 files changed

Lines changed: 362 additions & 0 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,102 @@ function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::
5252
return MatrixAlgebraKit.findtruncated(values, strategy)
5353
end
5454

55+
# Device kernel for the anti-Hermitian projection of a pair of mirrored off-diagonal blocks.
# The work-item with global index `col` walks down column `col` of `Au` and stores
# `(Au[row, col] - Al[col, row]') / 2` in `Bu[row, col]` and minus its adjoint in `Bl[col, row]`.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{true})
    nrows, ncols = size(Au)
    col = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    # surplus work-items past the last column have nothing to do
    if col > ncols
        return
    end
    @inbounds for row in 1:nrows
        # both inputs are read before either output is written
        skew = (Au[row, col] - adjoint(Al[col, row])) / 2
        Bu[row, col] = skew
        Bl[col, row] = -adjoint(skew)
    end
    return
end
68+
69+
# Device kernel for the Hermitian projection of a pair of mirrored off-diagonal blocks.
# The work-item with global index `col` walks down column `col` of `Au` and stores
# `(Au[row, col] + Al[col, row]') / 2` in `Bu[row, col]` and its adjoint in `Bl[col, row]`.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{false})
    nrows, ncols = size(Au)
    col = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    # surplus work-items past the last column have nothing to do
    if col > ncols
        return
    end
    @inbounds for row in 1:nrows
        # both inputs are read before either output is written
        sym = (Au[row, col] + adjoint(Al[col, row])) / 2
        Bu[row, col] = sym
        Bl[col, row] = adjoint(sym)
    end
    return
end
82+
83+
# Device kernel for the anti-Hermitian projection of a diagonal block `A` into `B`.
# The work-item with global index `col` handles every pair `(row, col)` with `row < col`,
# writing both the upper entry and its mirrored lower entry, and then the diagonal entry
# `(col, col)`, which is set through `MatrixAlgebraKit._imimag`.
function _project_hermitian_diag_kernel(A, B, ::Val{true})
    dim = size(A, 1)
    col = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    # surplus work-items past the last column have nothing to do
    if col > dim
        return
    end
    @inbounds for row in 1:(col - 1)
        # both inputs are read before either output is written
        skew = (A[row, col] - adjoint(A[col, row])) / 2
        B[row, col] = skew
        B[col, row] = -adjoint(skew)
    end
    @inbounds B[col, col] = MatrixAlgebraKit._imimag(A[col, col])
    return
end
97+
98+
# Device kernel for the Hermitian projection of a diagonal block `A` into `B`.
# The work-item with global index `col` handles every pair `(row, col)` with `row < col`,
# writing both the upper entry and its mirrored lower entry, and then the diagonal entry
# `(col, col)`, which is set to the real part of `A[col, col]`.
function _project_hermitian_diag_kernel(A, B, ::Val{false})
    dim = size(A, 1)
    col = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    # surplus work-items past the last column have nothing to do
    if col > dim
        return
    end
    @inbounds for row in 1:(col - 1)
        # both inputs are read before either output is written
        sym = (A[row, col] + adjoint(A[col, row])) / 2
        B[row, col] = sym
        B[col, row] = adjoint(sym)
    end
    @inbounds B[col, col] = real(A[col, col])
    return
end
112+
113+
# Launch the off-diagonal (anti-)Hermitian projection kernel on AMD GPUs.
#
# `Au`/`Bu` are `m × n` blocks and `Al`/`Bl` the mirrored `n × m` blocks; `anti` selects
# the anti-Hermitian (`true`) or Hermitian (`false`) kernel. One work-item is scheduled
# per column of `Au`. Returns `nothing`.
#
# Throws `DimensionMismatch` when the block shapes are inconsistent: the kernel indexes
# all four arrays under `@inbounds`, so a mismatch would otherwise be an out-of-bounds
# device access.
function MatrixAlgebraKit._project_hermitian_offdiag!(
        Au::StridedROCMatrix, Al::StridedROCMatrix, Bu::StridedROCMatrix, Bl::StridedROCMatrix, ::Val{anti}
    ) where {anti}
    m, n = size(Au)
    if size(Al) != (n, m) || size(Bu) != (m, n) || size(Bl) != (n, m)
        throw(DimensionMismatch())
    end
    # nothing to project for an empty block; also avoids requesting a zero-sized grid
    if m == 0 || n == 0
        return nothing
    end
    thread_dim = 512
    block_dim = cld(n, thread_dim)
    @roc groupsize = thread_dim gridsize = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
    return nothing
end
121+
# Launch the diagonal-block (anti-)Hermitian projection kernel on AMD GPUs.
#
# `A` is the square input block and `B` the output block of the same size; `anti` selects
# the anti-Hermitian (`true`) or Hermitian (`false`) kernel. One work-item is scheduled
# per column. Returns `nothing`.
#
# Throws `DimensionMismatch` when `A` is not square or `B` has a different size: the
# kernel indexes both arrays under `@inbounds`, so a mismatch would otherwise be an
# out-of-bounds device access.
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::StridedROCMatrix, ::Val{anti}) where {anti}
    n = size(A, 1)
    if size(A, 2) != n || size(B) != (n, n)
        throw(DimensionMismatch())
    end
    # nothing to project for an empty block; also avoids requesting a zero-sized grid
    if n == 0
        return nothing
    end
    thread_dim = 512
    block_dim = cld(n, thread_dim)
    @roc groupsize = thread_dim gridsize = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
    return nothing
end
127+
128+
# Exact Hermiticity test for a dense AMD GPU matrix: `true` iff `A[i, j] == A[j, i]'` for all i, j.
# A non-square matrix is never Hermitian; checking this first also keeps the broadcast
# below from throwing (or silently expanding a 1×n against an n×1 array).
function MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix)
    size(A, 1) == size(A, 2) || return false
    return all(A .== adjoint(A))
end
129+
# Exact Hermiticity test for a diagonal matrix stored on an AMD GPU: every diagonal entry
# must equal its own adjoint. The adjoint is taken element-wise (`adjoint.`); applying
# `adjoint` to the whole vector would produce a row vector and broadcast into an n×n
# all-pairs comparison, which is only true when all entries are equal.
function MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T}
    return all(A.diag .== adjoint.(A.diag))
end
130+
131+
# Exact anti-Hermiticity test for a dense AMD GPU matrix: `true` iff `A[i, j] == -A[j, i]'`
# for all i, j. A non-square matrix is never anti-Hermitian; checking this first also keeps
# the broadcast below from throwing (or silently expanding a 1×n against an n×1 array).
function MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix)
    size(A, 1) == size(A, 2) || return false
    return all(A .== -adjoint(A))
end
132+
# Exact anti-Hermiticity test for a diagonal matrix stored on an AMD GPU: every diagonal
# entry must equal minus its own adjoint. The adjoint is taken element-wise (`adjoint.`);
# applying `adjoint` to the whole vector would produce a row vector and broadcast into an
# n×n all-pairs comparison instead of an entry-by-entry one.
function MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T}
    return all(A.diag .== .-adjoint.(A.diag))
end
133+
134+
# In-place average/difference of two AMD GPU matrices with identical axes:
# afterwards `A` holds `(A + B) / 2` and `B` holds `B - A`, both computed element-wise
# from the original values. Returns `(A, B)`.
#
# Throws `DimensionMismatch` when the axes of `A` and `B` differ.
function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
    axes(A) == axes(B) || throw(DimensionMismatch())
    # nothing to update for empty input; also avoids requesting a zero-sized grid
    if isempty(A)
        return A, B
    end
    # one work-item per (linearly indexed) element
    function _avgdiff_kernel(A, B)
        j = workitemIdx().x + (workgroupIdx().x - 1) * workgroupDim().x
        j > length(A) && return
        @inbounds begin
            # read both operands before overwriting either of them
            a = A[j]
            b = B[j]
            A[j] = (a + b) / 2
            B[j] = b - a
        end
        return
    end
    thread_dim = 512
    block_dim = cld(length(A), thread_dim)
    @roc groupsize = thread_dim gridsize = block_dim _avgdiff_kernel(A, B)
    return A, B
end
152+
55153
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,102 @@ function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::T
5858
return MatrixAlgebraKit.findtruncated(values, strategy)
5959
end
6060

61+
# Device kernel for the anti-Hermitian projection of a pair of mirrored off-diagonal blocks.
# The thread with global index `col` walks down column `col` of `Au` and stores
# `(Au[row, col] - Al[col, row]') / 2` in `Bu[row, col]` and minus its adjoint in `Bl[col, row]`.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{true})
    nrows, ncols = size(Au)
    col = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    # surplus threads past the last column have nothing to do
    if col > ncols
        return
    end
    @inbounds for row in 1:nrows
        # both inputs are read before either output is written
        skew = (Au[row, col] - adjoint(Al[col, row])) / 2
        Bu[row, col] = skew
        Bl[col, row] = -adjoint(skew)
    end
    return
end
74+
75+
# Device kernel for the Hermitian projection of a pair of mirrored off-diagonal blocks.
# The thread with global index `j` walks down column `j` of `Au` and stores
# `(Au[i, j] + Al[j, i]') / 2` in `Bu[i, j]` and its adjoint in `Bl[j, i]`.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{false})
    m, n = size(Au)
    # `1i32` keeps the index arithmetic in Int32, as in the sibling kernels of this file
    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
    # surplus threads past the last column have nothing to do
    j > n && return
    for i in 1:m
        @inbounds begin
            # both inputs are read before either output is written
            val = (Au[i, j] + adjoint(Al[j, i])) / 2
            Bu[i, j] = val
            Bl[j, i] = adjoint(val)
        end
    end
    return
end
88+
89+
# Device kernel for the anti-Hermitian projection of a diagonal block `A` into `B`.
# The thread with global index `j` handles every pair `(i, j)` with `i < j`, writing both
# the upper entry and its mirrored lower entry, and then the diagonal entry `(j, j)`,
# which is set through `MatrixAlgebraKit._imimag`.
function _project_hermitian_diag_kernel(A, B, ::Val{true})
    n = size(A, 1)
    # `1i32` keeps `j` in Int32, so the Int32 loop range below is not promoted to Int64
    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
    # surplus threads past the last column have nothing to do
    j > n && return
    @inbounds begin
        for i in 1i32:(j - 1i32)
            # both inputs are read before either output is written
            val = (A[i, j] - adjoint(A[j, i])) / 2
            B[i, j] = val
            B[j, i] = -adjoint(val)
        end
        B[j, j] = MatrixAlgebraKit._imimag(A[j, j])
    end
    return
end
103+
104+
# Device kernel for the Hermitian projection of a diagonal block `A` into `B`.
# The thread with global index `col` handles every pair `(row, col)` with `row < col`,
# writing both the upper entry and its mirrored lower entry, and then the diagonal entry
# `(col, col)`, which is set to the real part of `A[col, col]`.
function _project_hermitian_diag_kernel(A, B, ::Val{false})
    dim = size(A, 1)
    col = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    # surplus threads past the last column have nothing to do
    if col > dim
        return
    end
    @inbounds for row in 1i32:(col - 1i32)
        # both inputs are read before either output is written
        sym = (A[row, col] + adjoint(A[col, row])) / 2
        B[row, col] = sym
        B[col, row] = adjoint(sym)
    end
    @inbounds B[col, col] = real(A[col, col])
    return
end
118+
119+
# Launch the off-diagonal (anti-)Hermitian projection kernel on CUDA GPUs.
#
# `Au`/`Bu` are `m × n` blocks and `Al`/`Bl` the mirrored `n × m` blocks; `anti` selects
# the anti-Hermitian (`true`) or Hermitian (`false`) kernel. One thread is scheduled per
# column of `Au`. Returns `nothing`.
#
# Throws `DimensionMismatch` when the block shapes are inconsistent: the kernel indexes
# all four arrays under `@inbounds`, so a mismatch would otherwise be an out-of-bounds
# device access.
function MatrixAlgebraKit._project_hermitian_offdiag!(
        Au::StridedCuMatrix, Al::StridedCuMatrix, Bu::StridedCuMatrix, Bl::StridedCuMatrix, ::Val{anti}
    ) where {anti}
    m, n = size(Au)
    if size(Al) != (n, m) || size(Bu) != (m, n) || size(Bl) != (n, m)
        throw(DimensionMismatch())
    end
    # nothing to project for an empty block, and a launch with zero blocks is rejected
    if m == 0 || n == 0
        return nothing
    end
    thread_dim = 512
    block_dim = cld(n, thread_dim)
    @cuda threads = thread_dim blocks = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
    return nothing
end
127+
# Launch the diagonal-block (anti-)Hermitian projection kernel on CUDA GPUs.
#
# `A` is the square input block and `B` the output block of the same size; `anti` selects
# the anti-Hermitian (`true`) or Hermitian (`false`) kernel. One thread is scheduled per
# column. Returns `nothing`.
#
# Throws `DimensionMismatch` when `A` is not square or `B` has a different size: the
# kernel indexes both arrays under `@inbounds`, so a mismatch would otherwise be an
# out-of-bounds device access.
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::StridedCuMatrix, ::Val{anti}) where {anti}
    n = size(A, 1)
    if size(A, 2) != n || size(B) != (n, n)
        throw(DimensionMismatch())
    end
    # nothing to project for an empty block, and a launch with zero blocks is rejected
    if n == 0
        return nothing
    end
    thread_dim = 512
    block_dim = cld(n, thread_dim)
    @cuda threads = thread_dim blocks = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
    return nothing
end
133+
134+
# Exact Hermiticity test for a dense CUDA matrix: `true` iff `A[i, j] == A[j, i]'` for all i, j.
# A non-square matrix is never Hermitian; checking this first also keeps the broadcast
# below from throwing (or silently expanding a 1×n against an n×1 array).
function MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix)
    size(A, 1) == size(A, 2) || return false
    return all(A .== adjoint(A))
end
135+
# Exact Hermiticity test for a diagonal matrix stored on a CUDA GPU: every diagonal entry
# must equal its own adjoint. The adjoint is taken element-wise (`adjoint.`); applying
# `adjoint` to the whole vector would produce a row vector and broadcast into an n×n
# all-pairs comparison, which is only true when all entries are equal.
function MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T}
    return all(A.diag .== adjoint.(A.diag))
end
136+
137+
# Exact anti-Hermiticity test for a dense CUDA matrix: `true` iff `A[i, j] == -A[j, i]'`
# for all i, j. A non-square matrix is never anti-Hermitian; checking this first also keeps
# the broadcast below from throwing (or silently expanding a 1×n against an n×1 array).
function MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix)
    size(A, 1) == size(A, 2) || return false
    return all(A .== -adjoint(A))
end
138+
# Exact anti-Hermiticity test for a diagonal matrix stored on a CUDA GPU: every diagonal
# entry must equal minus its own adjoint. The adjoint is taken element-wise (`adjoint.`);
# applying `adjoint` to the whole vector would produce a row vector and broadcast into an
# n×n all-pairs comparison instead of an entry-by-entry one.
function MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T}
    return all(A.diag .== .-adjoint.(A.diag))
end
139+
140+
# In-place average/difference of two CUDA matrices with identical axes:
# afterwards `A` holds `(A + B) / 2` and `B` holds `B - A`, both computed element-wise
# from the original values. Returns `(A, B)`.
#
# Throws `DimensionMismatch` when the axes of `A` and `B` differ.
function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
    axes(A) == axes(B) || throw(DimensionMismatch())
    # nothing to update for empty input, and a launch with zero blocks is rejected
    if isempty(A)
        return A, B
    end
    # one thread per (linearly indexed) element
    function _avgdiff_kernel(A, B)
        j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
        j > length(A) && return
        @inbounds begin
            # read both operands before overwriting either of them
            a = A[j]
            b = B[j]
            A[j] = (a + b) / 2
            B[j] = b - a
        end
        return
    end
    thread_dim = 512
    block_dim = cld(length(A), thread_dim)
    @cuda threads = thread_dim blocks = block_dim _avgdiff_kernel(A, B)
    return A, B
end
158+
61159
end

test/amd/projections.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, Diagonal, norm
6+
using AMDGPU
7+
8+
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9+
10+
# Checks the Hermitian / anti-Hermitian projections on AMD GPU arrays for several block sizes,
# both out-of-place (input must stay untouched) and in-place (result must alias the input).
@testset "project_(anti)hermitian! for T = $T" for T in BLASFloats
    rng = StableRNG(123)
    m = 54
    noisefactor = eps(real(T))^(3 / 4)
    for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
        A = ROCArray(randn(rng, T, m, m))
        Ah = (A + A') / 2
        Aa = (A - A') / 2
        Ac = copy(A)

        Bh = project_hermitian(A, alg)
        @test ishermitian(Bh)
        @test Bh ≈ Ah
        @test A == Ac
        # a small anti-Hermitian perturbation must break exact, but not approximate, hermiticity
        Bh_approx = Bh + noisefactor * Aa
        @test !ishermitian(Bh_approx)
        @test ishermitian(Bh_approx; rtol = 10 * noisefactor)

        Ba = project_antihermitian(A, alg)
        @test isantihermitian(Ba)
        @test Ba ≈ Aa
        @test A == Ac
        # a small Hermitian perturbation must break exact, but not approximate, anti-hermiticity
        Ba_approx = Ba + noisefactor * Ah
        @test !isantihermitian(Ba_approx)
        @test isantihermitian(Ba_approx; rtol = 10 * noisefactor)

        Bh = project_hermitian!(Ac, alg)
        @test Bh === Ac
        @test ishermitian(Bh)
        @test Bh ≈ Ah

        copy!(Ac, A)
        Ba = project_antihermitian!(Ac, alg)
        @test Ba === Ac
        @test isantihermitian(Ba)
        @test Ba ≈ Aa
    end
end
48+
49+
# Checks the isometric projection on AMD GPU arrays for tall and square inputs.
@testset "project_isometric! for T = $T" for T in BLASFloats
    rng = StableRNG(123)
    m = 54
    @testset "size ($m, $n)" for n in (37, m)
        # NOTE(review): these are LAPACK (CPU) SVD algorithms while `A` below is a `ROCArray`;
        # confirm they are supported for AMD GPU arrays or switch to the ROCm solver algorithms.
        if LinearAlgebra.LAPACK.version() < v"3.12.0"
            svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
        else
            svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_Jacobi())
        end
        algs = (PolarViaSVD.(svdalgs)..., PolarNewton())
        @testset "algorithm $alg" for alg in algs
            A = ROCArray(randn(rng, T, m, n))
            W = project_isometric(A, alg)
            @test isisometric(W)
            W2 = project_isometric(W, alg)
            @test W2 ≈ W # stability of the projection
            @test W * (W' * A) ≈ A

            Ac = similar(A)
            W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
            @test W2 === W
            @test isisometric(W)

            # test that W is closer to A than any other isometry
            W = project_isometric(A, alg)
            for _ in 1:10
                δA = ROCArray(randn(rng, T, m, n))
                W2 = project_isometric(A + δA / 100, alg)
                @test norm(A - W2) > norm(A - W)
            end
        end
    end
end

test/cuda/projections.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, Diagonal, norm
6+
using CUDA
7+
8+
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9+
10+
# Checks the Hermitian / anti-Hermitian projections on CUDA arrays for several block sizes,
# both out-of-place (input must stay untouched) and in-place (result must alias the input).
@testset "project_(anti)hermitian! for T = $T" for T in BLASFloats
    rng = StableRNG(123)
    m = 54
    noisefactor = eps(real(T))^(3 / 4)
    for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
        A = CuArray(randn(rng, T, m, m))
        Ah = (A + A') / 2
        Aa = (A - A') / 2
        Ac = copy(A)

        Bh = project_hermitian(A, alg)
        @test ishermitian(Bh)
        @test Bh ≈ Ah
        @test A == Ac
        # a small anti-Hermitian perturbation must break exact, but not approximate, hermiticity
        Bh_approx = Bh + noisefactor * Aa
        @test !ishermitian(Bh_approx)
        @test ishermitian(Bh_approx; rtol = 10 * noisefactor)

        Ba = project_antihermitian(A, alg)
        @test isantihermitian(Ba)
        @test Ba ≈ Aa
        @test A == Ac
        # a small Hermitian perturbation must break exact, but not approximate, anti-hermiticity
        Ba_approx = Ba + noisefactor * Ah
        @test !isantihermitian(Ba_approx)
        @test isantihermitian(Ba_approx; rtol = 10 * noisefactor)

        Bh = project_hermitian!(Ac, alg)
        @test Bh === Ac
        @test ishermitian(Bh)
        @test Bh ≈ Ah

        copy!(Ac, A)
        Ba = project_antihermitian!(Ac, alg)
        @test Ba === Ac
        @test isantihermitian(Ba)
        @test Ba ≈ Aa
    end
end
48+
49+
# Checks the isometric projection on CUDA arrays for tall and square inputs.
@testset "project_isometric! for T = $T" for T in BLASFloats
    rng = StableRNG(123)
    m = 54
    @testset "size ($m, $n)" for n in (37, m)
        svdalgs = (CUSOLVER_SVDPolar(), CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
        algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO
        @testset "algorithm $alg" for alg in algs
            A = CuArray(randn(rng, T, m, n))
            W = project_isometric(A, alg)
            @test isisometric(W)
            W2 = project_isometric(W, alg)
            @test W2 ≈ W # stability of the projection
            @test W * (W' * A) ≈ A

            Ac = similar(A)
            W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
            @test W2 === W
            @test isisometric(W)

            # test that W is closer to A than any other isometry
            W = project_isometric(A, alg)
            for _ in 1:10
                δA = CuArray(randn(rng, T, m, n))
                W2 = project_isometric(A + δA / 100, alg)
                @test norm(A - W2) > norm(A - W)
            end
        end
    end
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ if CUDA.functional()
6363
@safetestset "CUDA LQ" begin
6464
include("cuda/lq.jl")
6565
end
66+
@safetestset "CUDA Projections" begin
67+
include("cuda/projections.jl")
68+
end
6669
@safetestset "CUDA SVD" begin
6770
include("cuda/svd.jl")
6871
end
@@ -82,6 +85,9 @@ if AMDGPU.functional()
8285
@safetestset "AMDGPU LQ" begin
8386
include("amd/lq.jl")
8487
end
88+
@safetestset "AMDGPU Projections" begin
89+
include("amd/projections.jl")
90+
end
8591
@safetestset "AMDGPU SVD" begin
8692
include("amd/svd.jl")
8793
end

0 commit comments

Comments
 (0)