Skip to content

Commit b47d076

Browse files
committed
Piracy fix
1 parent 6452b9e commit b47d076

7 files changed

Lines changed: 18 additions & 15 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
10-
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!
10+
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester
1111
using AMDGPU
1212
using LinearAlgebra
1313
using LinearAlgebra: BlasFloat
@@ -171,4 +171,9 @@ end
171171
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) =
172172
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
173173

174+
function _sylvester(A::AnyROCMatrix, B::AnyROCMatrix, C::AnyROCMatrix)
175+
hX = sylvester(collect(A), collect(B), collect(C))
176+
return ROCArray(hX)
177+
end
178+
174179
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, defaul
77
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
10-
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
10+
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester
1111
using CUDA, CUDA.CUBLAS
1212
using CUDA: i32
1313
using LinearAlgebra
@@ -202,14 +202,9 @@ function MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray, As...)
202202
return isempty(As′) ? 0 : eps(norm(CuArray.(As′), Inf))^(3 / 4)
203203
end
204204

205-
function LinearAlgebra.sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
206-
#=m = size(A, 1)
207-
n = size(B, 2)
208-
I_n = fill!(similar(A, n), one(eltype(A)))
209-
I_m = fill!(similar(B, m), one(eltype(B)))
210-
L = kron(diagm(I_n), A) + kron(adjoint(B), diagm(I_m))
211-
x_vec = L \ -vec(C)
212-
X = CuMatrix(reshape(x_vec, m, n))=#
205+
function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
206+
# https://github.com/JuliaGPU/CUDA.jl/issues/3021
207+
# to add native sylvester to CUDA
213208
hX = sylvester(collect(A), collect(B), collect(C))
214209
return CuArray(hX)
215210
end

src/common/pullbacks.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ function iszerotangent end
1010

1111
iszerotangent(::Any) = false
1212
iszerotangent(::Nothing) = true
13+
14+
# fallback
15+
_sylvester(A, B, C) = LinearAlgebra.sylvester(A, B, C)

src/pullbacks/eig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ function eig_trunc_pullback!(
140140
# add contribution from orthogonal complement
141141
PA = A - (A * V) / V
142142
Y = mul!(ΔVperp, PA', Z, 1, 1)
143-
X = sylvester(PA', -Dmat', Y)
143+
X = _sylvester(PA', -Dmat', Y)
144144
Z .+= X
145145

146146
if eltype(ΔA) <: Real

src/pullbacks/eigh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ function eigh_trunc_pullback!(
142142
# add contribution from orthogonal complement
143143
W = qr_null(V)
144144
WᴴΔV = W' * ΔV
145-
X = sylvester(W' * A * W, -Dmat, WᴴΔV)
145+
X = _sylvester(W' * A * W, -Dmat, WᴴΔV)
146146
Z = mul!(Z, W, X, 1, 1)
147147

148148
# put everything together: symmetrize for hermitian case

src/pullbacks/polar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
1616
M = zero(P)
1717
!iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1)
1818
!iszerotangent(ΔP) && mul!(M, ΔP, P, -1, 1)
19-
C = sylvester(P, P, M' - M)
19+
C = _sylvester(P, P, M' - M)
2020
C .+= ΔP
2121
ΔA = mul!(ΔA, W, C, 1, 1)
2222
if !iszerotangent(ΔW)
@@ -52,7 +52,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
5252
M = zero(P)
5353
!iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1)
5454
!iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
55-
C = sylvester(P, P, M' - M)
55+
C = _sylvester(P, P, M' - M)
5656
C .+= ΔP
5757
ΔA = mul!(ΔA, C, Wᴴ, 1, 1)
5858
if !iszerotangent(ΔWᴴ)

src/pullbacks/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ function svd_trunc_pullback!(
207207
else
208208
fill!(view(rhs, m̃ .+ (1:ñ), :), 0)
209209
end
210-
XY = sylvester(ÃÃ, -Smat, rhs)
210+
XY = _sylvester(ÃÃ, -Smat, rhs)
211211
X = view(XY, 1:m̃, :)
212212
Y = view(XY, m̃ .+ (1:ñ), :)
213213
ΔA = mul!(ΔA, Ũ, X * Vᴴ, 1, 1)

0 commit comments

Comments
 (0)