Skip to content

Commit 15feb46

Browse files
author
Katharine Hyatt
committed
Updates for AMD
1 parent 9255cfc commit 15feb46

1 file changed

Lines changed: 118 additions & 0 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,31 @@ end
2424
# Default SVD algorithm for strided ROC matrix types: a `ROCSOLVER_QRIteration`
# built from the keyword arguments supplied by the caller.
function MatrixAlgebraKit.default_svd_algorithm(
        ::Type{T}; kwargs...
    ) where {T<:StridedROCMatrix}
    alg = ROCSOLVER_QRIteration(; kwargs...)
    return alg
end
27+
# There is no default general (non-Hermitian) eigenvalue algorithm for strided
# ROC matrices; asking for one always raises an `ErrorException`.
function MatrixAlgebraKit.default_eig_algorithm(
        ::Type{T}; kwargs...
    ) where {T<:StridedROCMatrix}
    # `error(msg)` throws `ErrorException(msg)`.
    return error("AMDGPU has no support for general eigensolving")
end
2730
# Default Hermitian eigenvalue algorithm for strided ROC matrix types: a
# `ROCSOLVER_DivideAndConquer` built from the caller's keyword arguments.
function MatrixAlgebraKit.default_eigh_algorithm(
        ::Type{T}; kwargs...
    ) where {T<:StridedROCMatrix}
    alg = ROCSOLVER_DivideAndConquer(; kwargs...)
    return alg
end
3033

34+
# Default algorithms for 2-D reshapes of contiguous 1-D views into ROC arrays (needed for block sector support)
35+
# Default QR algorithm for a 2-D `ReshapedArray` wrapping a contiguous
# (`UnitRange`-indexed) 1-D `SubArray` of a ROC vector or matrix: a
# `ROCSOLVER_HouseholderQR` built from the caller's keyword arguments.
function MatrixAlgebraKit.default_qr_algorithm(
        ::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}};
        kwargs...
    ) where {T<:BlasFloat, A<:ROCVecOrMat{T}}
    alg = ROCSOLVER_HouseholderQR(; kwargs...)
    return alg
end
38+
# Default LQ algorithm for a 2-D `ReshapedArray` wrapping a contiguous
# (`UnitRange`-indexed) 1-D `SubArray` of a ROC vector or matrix: an
# `LQViaTransposedQR` wrapping a `ROCSOLVER_HouseholderQR` that receives the
# caller's keyword arguments.
function MatrixAlgebraKit.default_lq_algorithm(
        ::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}};
        kwargs...
    ) where {T<:BlasFloat, A<:ROCVecOrMat{T}}
    return LQViaTransposedQR(ROCSOLVER_HouseholderQR(; kwargs...))
end
42+
# Default SVD algorithm for a 2-D `ReshapedArray` wrapping a contiguous
# (`UnitRange`-indexed) 1-D `SubArray` of a ROC vector or matrix: a
# `ROCSOLVER_Jacobi` built from the caller's keyword arguments.
# NOTE(review): the plain `StridedROCMatrix` method defaults to
# `ROCSOLVER_QRIteration` instead — confirm the difference is intentional.
function MatrixAlgebraKit.default_svd_algorithm(
        ::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}};
        kwargs...
    ) where {T<:BlasFloat, A<:ROCVecOrMat{T}}
    alg = ROCSOLVER_Jacobi(; kwargs...)
    return alg
end
45+
# There is no default general (non-Hermitian) eigenvalue algorithm for a 2-D
# `ReshapedArray` wrapping a contiguous 1-D `SubArray` of a ROC vector or
# matrix; asking for one always raises an `ErrorException`.
function MatrixAlgebraKit.default_eig_algorithm(
        ::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}};
        kwargs...
    ) where {T<:BlasFloat, A<:ROCVecOrMat{T}}
    # `error(msg)` throws `ErrorException(msg)`.
    return error("AMDGPU has no support for general eigensolving")
end
48+
# Default Hermitian eigenvalue algorithm for a 2-D `ReshapedArray` wrapping a
# contiguous (`UnitRange`-indexed) 1-D `SubArray` of a ROC vector or matrix: a
# `ROCSOLVER_DivideAndConquer` built from the caller's keyword arguments.
function MatrixAlgebraKit.default_eigh_algorithm(
        ::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}};
        kwargs...
    ) where {T<:BlasFloat, A<:ROCVecOrMat{T}}
    alg = ROCSOLVER_DivideAndConquer(; kwargs...)
    return alg
end
51+
3152
# `MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix)` is defined once, further
# down in this file (next to `isantihermitian_exact`), as `all(A .== adjoint(A))`.
# The former definition here (`= ishermitian(A)`) had the identical signature, so it
# was always replaced by the later one when the file loaded, and overwriting a method
# inside a module is rejected during precompilation on recent Julia versions.
3253

3354
# Forward `geqrf!` on a strided ROC matrix to `YArocSOLVER.geqrf!` and return
# whatever that call returns.
function _gpu_geqrf!(A::StridedROCMatrix)
    return YArocSOLVER.geqrf!(A)
end
@@ -54,4 +75,101 @@ function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::
5475
return MatrixAlgebraKit.findtruncated(values, strategy)
5576
end
5677

78+
# Device kernel, anti-Hermitian variant (`Val{true}`).
#
# Each work-item handles one column `col` of `Au`. For every row it computes
# `half = (Au[row, col] - adjoint(Al[col, row])) / 2`, then stores `half` in
# `Bu[row, col]` and `-adjoint(half)` in `Bl[col, row]`.
# Work-items whose index exceeds the column count of `Au` exit immediately.
# Indexing is `@inbounds`: `Al`, `Bu`, `Bl` are assumed to have sizes compatible
# with `Au` (resp. its transpose) — not checked here.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{true})
    nrows, ncols = size(Au)
    col = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    col <= ncols || return
    @inbounds for row in 1:nrows
        half = (Au[row, col] - adjoint(Al[col, row])) / 2
        Bu[row, col] = half
        Bl[col, row] = -adjoint(half)
    end
    return
end
91+
92+
# Device kernel, Hermitian variant (`Val{false}`).
#
# Each work-item handles one column `col` of `Au`. For every row it computes
# `half = (Au[row, col] + adjoint(Al[col, row])) / 2`, then stores `half` in
# `Bu[row, col]` and `adjoint(half)` in `Bl[col, row]`.
# Work-items whose index exceeds the column count of `Au` exit immediately.
# Indexing is `@inbounds`: `Al`, `Bu`, `Bl` are assumed to have sizes compatible
# with `Au` (resp. its transpose) — not checked here.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{false})
    nrows, ncols = size(Au)
    col = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    col <= ncols || return
    @inbounds for row in 1:nrows
        half = (Au[row, col] + adjoint(Al[col, row])) / 2
        Bu[row, col] = half
        Bl[col, row] = adjoint(half)
    end
    return
end
105+
106+
# Device kernel, anti-Hermitian variant (`Val{true}`), for a diagonal block.
#
# Each work-item handles one index `col` (up to `size(A, 1)`). For every
# `row < col` it computes `half = (A[row, col] - adjoint(A[col, row])) / 2`,
# then stores `half` in `B[row, col]` and `-adjoint(half)` in `B[col, row]`.
# The diagonal entry `B[col, col]` is set to `MatrixAlgebraKit._imimag(A[col, col])`
# (presumably the purely imaginary part of the entry — confirm against
# MatrixAlgebraKit).
# Indexing is `@inbounds`; `A` and `B` are assumed square and equally sized.
function _project_hermitian_diag_kernel(A, B, ::Val{true})
    dim = size(A, 1)
    col = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    col <= dim || return
    @inbounds for row in 1:(col - 1)
        half = (A[row, col] - adjoint(A[col, row])) / 2
        B[row, col] = half
        B[col, row] = -adjoint(half)
    end
    @inbounds B[col, col] = MatrixAlgebraKit._imimag(A[col, col])
    return
end
120+
121+
# Device kernel, Hermitian variant (`Val{false}`), for a diagonal block.
#
# Each work-item handles one index `col` (up to `size(A, 1)`). For every
# `row < col` it computes `half = (A[row, col] + adjoint(A[col, row])) / 2`,
# then stores `half` in `B[row, col]` and `adjoint(half)` in `B[col, row]`.
# The diagonal entry `B[col, col]` is set to `real(A[col, col])`.
# Indexing is `@inbounds`; `A` and `B` are assumed square and equally sized.
function _project_hermitian_diag_kernel(A, B, ::Val{false})
    dim = size(A, 1)
    col = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    col <= dim || return
    @inbounds for row in 1:(col - 1)
        half = (A[row, col] + adjoint(A[col, row])) / 2
        B[row, col] = half
        B[col, row] = adjoint(half)
    end
    @inbounds B[col, col] = real(A[col, col])
    return
end
135+
136+
# Launch `_project_hermitian_offdiag_kernel` on ROC matrices.
#
# One work-item is launched per column of `Au` (in workgroups of 512); the
# kernel reads `Au`/`Al` and writes `Bu`/`Bl`. `anti` selects the
# anti-Hermitian (`true`) or Hermitian (`false`) kernel method.
# Always returns `nothing`.
function MatrixAlgebraKit._project_hermitian_offdiag!(
        Au::StridedROCMatrix, Al::StridedROCMatrix, Bu::StridedROCMatrix, Bl::StridedROCMatrix, ::Val{anti}
    ) where {anti}
    # Nothing to do for an empty block. This also avoids requesting a launch
    # with `gridsize = 0`, which is presumably rejected by the HIP runtime.
    isempty(Au) && return nothing
    thread_dim = 512
    block_dim = cld(size(Au, 2), thread_dim)
    @roc groupsize=thread_dim gridsize=block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
    return nothing
end
144+
# Launch `_project_hermitian_diag_kernel` on ROC matrices.
#
# One work-item is launched per row/column index of `A` (in workgroups of 512);
# the kernel reads `A` and writes `B`. `anti` selects the anti-Hermitian
# (`true`) or Hermitian (`false`) kernel method.
# Always returns `nothing`.
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::StridedROCMatrix, ::Val{anti}) where {anti}
    # Nothing to do for an empty block. This also avoids requesting a launch
    # with `gridsize = 0`, which is presumably rejected by the HIP runtime.
    size(A, 1) == 0 && return nothing
    thread_dim = 512
    block_dim = cld(size(A, 1), thread_dim)
    @roc groupsize=thread_dim gridsize=block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
    return nothing
end
150+
151+
# Exact Hermiticity test for a strided ROC matrix: every entry must equal the
# corresponding entry of the adjoint, evaluated as a single broadcast + `all`.
function MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix)
    return all(A .== A')
end
152+
# Exact Hermiticity test for a `Diagonal` backed by a ROC vector: each diagonal
# entry must equal its own adjoint.
#
# The adjoint is applied element-wise (`adjoint.`). Taking `adjoint(A.diag)` of
# the whole vector yields a 1×n row, and broadcasting an n-vector against it
# compares every pair of entries, so e.g. a real diagonal with distinct entries
# would wrongly be reported as non-Hermitian.
function MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T}
    return all(A.diag .== adjoint.(A.diag))
end
153+
154+
# Exact anti-Hermiticity test for a strided ROC matrix: every entry must equal
# the corresponding entry of the negated adjoint, evaluated as a single
# broadcast + `all`.
function MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix)
    return all(A .== -(A'))
end
155+
# Exact anti-Hermiticity test for a `Diagonal` backed by a ROC vector: each
# diagonal entry must equal the negation of its own adjoint.
#
# The adjoint is applied element-wise (`adjoint.`). Taking `adjoint(A.diag)` of
# the whole vector yields a 1×n row, and broadcasting an n-vector against it
# compares every pair of entries instead of each entry with itself.
function MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T}
    return all(A.diag .== .-adjoint.(A.diag))
end
156+
157+
# In-place average/difference of two ROC matrices.
#
# For every linear index `j`, with `a = A[j]` and `b = B[j]` read first, writes
# `A[j] = (a + b) / 2` and `B[j] = b - a`. Returns the tuple `(A, B)`.
#
# Throws `DimensionMismatch` when `axes(A) != axes(B)`.
function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
    axes(A) == axes(B) || throw(DimensionMismatch())
    # Nothing to do for empty inputs. This also avoids requesting a launch with
    # `gridsize = 0`, which is presumably rejected by the HIP runtime.
    if isempty(A)
        return A, B
    end
    # One work-item per element; work-items past the end exit immediately.
    function _avgdiff_kernel(X, Y)
        j = workitemIdx().x + (workgroupIdx().x - 1) * workgroupDim().x
        j > length(X) && return
        @inbounds begin
            a = X[j]
            b = Y[j]
            X[j] = (a + b) / 2
            Y[j] = b - a
        end
        return
    end
    thread_dim = 512
    block_dim = cld(length(A), thread_dim)
    # Launch with `@roc`, as the other launchers in this extension do; `@cuda`
    # is the CUDA.jl launcher and takes different launch keywords.
    @roc groupsize=thread_dim gridsize=block_dim _avgdiff_kernel(A, B)
    return A, B
end
57175
end

0 commit comments

Comments
 (0)