Skip to content

Commit 92f576a

Browse files
kshyatt, lkdvos, Jutho
authored
Make svd_pullback! GPU-compatible (#232)
* Fix and test AD rules for SVD * Missing CUDA functional check * Comment response * Try to fix dimension mismatch issue * Formatter * Init maximum in case of empty array * Typo * Get rid of collect in Enzyme ext * Views seem to break stuff * Apply suggestions from code review Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Missing brace * Update src/pullbacks/svd.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com> Co-authored-by: Jutho <Jutho@users.noreply.github.com>
1 parent 5677370 commit 92f576a

7 files changed

Lines changed: 45 additions & 14 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTrun
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj!
1010
import MatrixAlgebraKit: heevj!, heevd!, heev!, heevx!
11-
import MatrixAlgebraKit: _sylvester, svd_rank
11+
import MatrixAlgebraKit: _sylvester, svd_rank, svd_pullback!
1212
using AMDGPU
1313
using LinearAlgebra
1414
using LinearAlgebra: BlasFloat
@@ -185,6 +185,12 @@ function _sylvester(A::AnyROCMatrix, B::AnyROCMatrix, C::AnyROCMatrix)
185185
return ROCArray(hX)
186186
end
187187

188-
svd_rank(S::AnyROCVector, rank_atol) = findlast(s -> s ≥ rank_atol, S)
188+
function svd_rank(S::AnyROCVector; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(S))
189+
return something(findlast(≥(rank_atol), S), 0)
190+
end
191+
192+
function svd_pullback!(ΔA::AnyROCMatrix, A, USVᴴ, ΔUSVᴴ, ind::AnyROCVector; kwargs...)
193+
return svd_pullback!(ΔA, A, USVᴴ, ΔUSVᴴ, collect(ind); kwargs...)
194+
end
189195

190196
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, Abstract
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!
1010
import MatrixAlgebraKit: heevj!, heevd!, geev!
11-
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank
11+
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!
1212
using CUDA, CUDA.cuBLAS
1313
using CUDA: i32
1414
using LinearAlgebra
@@ -197,6 +197,12 @@ function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
197197
return CuArray(hX)
198198
end
199199

200-
svd_rank(S::AnyCuVector, rank_atol) = findlast(s -> s ≥ rank_atol, S)
200+
function svd_rank(S::AnyCuVector; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(S))
201+
return something(findlast(≥(rank_atol), S), 0)
202+
end
203+
204+
function svd_pullback!(ΔA::AnyCuMatrix, A, USVᴴ, ΔUSVᴴ, ind::AnyCuVector; kwargs...)
205+
return svd_pullback!(ΔA, A, USVᴴ, ΔUSVᴴ, collect(ind); kwargs...)
206+
end
201207

202208
end

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,20 @@ using LinearAlgebra
1515

1616
Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent
1717

18+
# needed for GPU tests because Mooncake can't differentiate through CUDA kernels
19+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(zero!), AbstractArray}
20+
function Mooncake.rrule!!(::CoDual{typeof(zero!)}, A_dA::CoDual)
21+
A, dA = arrayify(A_dA)
22+
Ac = copy(A)
23+
zero!(A)
24+
function zero_adjoint(::NoRData)
25+
copy!(A, Ac)
26+
zero!(dA)
27+
return NoRData(), NoRData()
28+
end
29+
return A_dA, zero_adjoint
30+
end
31+
1832
# two-argument in-place factorizations like LQ, QR, EIG
1933
for (f!, f, pb, adj) in (
2034
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),

src/pullbacks/svd.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,19 +85,17 @@ function check_and_prepare_svd_cotangents(
8585
bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s₁, s₂, u, v
8686
return abs(s₁ - s₂) < degeneracy_atol ? u + v : zero(u) + zero(v)
8787
end
88-
Δgauge = max(Δgauge, norm(bc, Inf))
88+
Δgauge = max(Δgauge, maximum(abs, bc))
8989

9090
if !iszerotangent(ΔSmat)
9191
ΔS = diagview(ΔSmat)
9292
length(indS) == length(ΔS) || throw(DimensionMismatch(lazy"length of selected S values ($(length(indS))) does not match length of ΔS ($(length(ΔS)))"))
93+
bad_indS = _ind_intersect((r + 1):length(ΔS), indS)
94+
good_indS = _ind_intersect(1:r, indS)
9395
ΔS₁ = zero(S₁)
94-
for (j, i) in enumerate(indS)
95-
if i <= r
96-
ΔS₁[i] = real(ΔS[j])
97-
else
98-
Δgauge = max(Δgauge, abs(ΔS[j]))
99-
end
100-
end
96+
ΔS₁[1:length(good_indS)] .= real.(ΔS[good_indS])
97+
badΔS₁ = view(ΔS, bad_indS)
98+
Δgauge = max(Δgauge, maximum(abs, badΔS₁; init = abs(zero(eltype(ΔS)))))
10199
else
102100
ΔS₁ = nothing
103101
end

test/mooncake/svd.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,11 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
2020
TestSuite.test_mooncake_svd(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
2121
end
2222
end
23+
if T ∈ BLASFloats && CUDA.functional()
24+
TestSuite.test_mooncake_svd(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
25+
if m == n
26+
AT = Diagonal{T, CuVector{T}}
27+
TestSuite.test_mooncake_svd(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
28+
end
29+
end
2330
end

test/testsuite/enzyme/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function test_enzyme_svd_trunc(
6969
end
7070
@testset "trunctol" begin
7171
S = svd_vals(A, alg)
72-
trunc = trunctol(atol = S[1] / 2)
72+
trunc = trunctol(atol = maximum(S) / 2)
7373
truncalg = TruncatedAlgorithm(alg, trunc)
7474
USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
7575
test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)

test/testsuite/mooncake/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ function test_mooncake_svd_trunc(
141141

142142
@testset "trunctol" begin
143143
S = svd_vals(A)
144-
trunc = trunctol(atol = S[1] / 2)
144+
trunc = trunctol(atol = maximum(S) / 2)
145145
alg_trunc = TruncatedAlgorithm(alg, trunc)
146146

147147
USVᴴ, USVᴴtrunc, ΔUSVᴴ_arrays, ΔUSVᴴtrunc_arrays = ad_svd_trunc_setup(A, alg_trunc)

0 commit comments

Comments
 (0)