Skip to content

Commit 3a98285

Browse files
authored
Working pullback and tests for EIGH + CUDA (#235)
1 parent 92f576a commit 3a98285

3 files changed

Lines changed: 13 additions & 3 deletions

File tree

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, Abstract
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!
1010
import MatrixAlgebraKit: heevj!, heevd!, geev!
11-
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!
11+
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!, eigh_pullback!
1212
using CUDA, CUDA.cuBLAS
1313
using CUDA: i32
1414
using LinearAlgebra
@@ -205,4 +205,8 @@ function svd_pullback!(ΔA::AnyCuMatrix, A, USVᴴ, ΔUSVᴴ, ind::AnyCuVector;
205205
return svd_pullback!(ΔA, A, USVᴴ, ΔUSVᴴ, collect(ind); kwargs...)
206206
end
207207

208+
function eigh_pullback!(ΔA::AnyCuMatrix, A, DV, ΔDV, ind::AnyCuVector; kwargs...)
209+
return eigh_pullback!(ΔA, A, DV, ΔDV, collect(ind); kwargs...)
210+
end
211+
208212
end

src/pullbacks/eigh.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function check_and_prepare_eigh_cotangents(
3131
bc = Base.broadcasted(transpose(D), D, aVᴴΔV₁) do d₁, d₂, v
3232
return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v)
3333
end
34-
Δgauge = norm(bc, Inf)
34+
Δgauge = maximum(abs, bc; init = abs(zero(eltype(D))))
3535

3636
Δgauge gauge_atol ||
3737
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
@@ -42,7 +42,8 @@ function check_and_prepare_eigh_cotangents(
4242
if !iszerotangent(ΔDmat)
4343
ΔD = diagview(ΔDmat)
4444
length(indD) == length(ΔD) || throw(DimensionMismatch())
45-
view(diagview(VᴴAΔV), indD) .+= real.(ΔD)
45+
# needed to avoid GPUCompiler errors
46+
VᴴAΔV[diagind(VᴴAΔV)[indD]] .+= real.(ΔD)
4647
else
4748
ΔD = nothing
4849
end

test/mooncake/eigh.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,9 @@ for T in (BLASFloats..., GenericFloats...)
1818
AT = Diagonal{T, Vector{T}}
1919
TestSuite.test_mooncake_eigh(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
2020
end
21+
if T BLASFloats && CUDA.functional()
22+
TestSuite.test_mooncake_eigh(CuMatrix{T}, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
23+
AT = Diagonal{T, CuVector{T}}
24+
TestSuite.test_mooncake_eigh(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
25+
end
2126
end

0 commit comments

Comments
 (0)