Skip to content

Commit b919563

Browse files
author
Katharine Hyatt
committed
Some basic svd forward rules and tests
1 parent 742572f commit b919563

7 files changed

Lines changed: 263 additions & 16 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using MatrixAlgebraKit: qr_pullback!, lq_pullback!
77
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
88
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback!
99
using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback!
10+
using MatrixAlgebraKit: svd_pushforward!, svd_vals_pushforward!
1011
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
1112
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
1213
using Enzyme
@@ -257,6 +258,30 @@ for f in (:svd_compact!, :svd_full!)
257258
!isa(USVᴴ, Const) && make_zero!(USVᴴ.dval)
258259
return (nothing, nothing, nothing)
259260
end
261+
# Width-1 forward-mode Enzyme rule for the in-place SVD currently being
# generated by the surrounding `@eval` loop.
#
# Runs the primal factorization into `USVᴴ.val` and, when both `A` and `USVᴴ`
# carry tangents, fills `USVᴴ.dval` via `svd_pushforward!`. The value returned
# follows `needs_primal(config)` / `needs_shadow(config)`.
function EnzymeRules.forward(
        config::EnzymeRules.FwdConfigWidth{1},
        func::Const{typeof($f)},
        ::Type{RT},
        A::Annotation,
        USVᴴ::Annotation,
        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
    ) where {RT}
    A_active = !isa(A, Const)
    out_active = !isa(USVᴴ, Const)
    # The in-place factorization may use `A.val` as workspace, so keep an
    # untouched copy for the pushforward, which reads the primal input.
    A_orig = (A_active && out_active) ? copy(A.val) : A.val
    $f(A.val, USVᴴ.val, alg.val)
    if out_active
        # Always reset the output tangent: the pushforward only writes part of
        # it, and a constant input must yield an exactly zero tangent.
        make_zero!(USVᴴ.dval)
        A_active && svd_pushforward!(A.dval, A_orig, USVᴴ.val, USVᴴ.dval)
    end
    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        return USVᴴ
    elseif EnzymeRules.needs_primal(config)
        return USVᴴ.val
    elseif EnzymeRules.needs_shadow(config)
        return USVᴴ.dval
    else
        return nothing
    end
end
260285
end
261286
end
262287

@@ -467,5 +492,32 @@ function EnzymeRules.reverse(
467492
!isa(S, Const) && !A_is_arg && make_zero!(S.dval)
468493
return (nothing, nothing, nothing)
469494
end
495+
# Width-1 forward-mode Enzyme rule for `svd_vals!`.
#
# The primal singular values are obtained from `svd_compact!` and copied into
# `S.val`; the tangent of the singular values is written into `S.dval` through
# `svd_vals_pushforward!`. `A_is_arg` marks the case where the output tangent
# `S.dval` is the very same array as the diagonal of a `Diagonal` input tangent,
# in which case the result is built in a scratch buffer and copied back last.
function EnzymeRules.forward(
        config::EnzymeRules.FwdConfigWidth{1},
        func::Const{typeof(svd_vals!)},
        ::Type{RT},
        A::Annotation{TA},
        S::Annotation,
        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
    ) where {RT, TA}
    A_active = !isa(A, Const)
    S_active = !isa(S, Const)
    # Only touch `S.dval` when `S` actually carries a shadow.
    A_is_arg = A_active && S_active && TA <: Diagonal && diagview(A.dval) === S.dval
    # `svd_compact!` may overwrite `A.val`; the pushforward needs the original.
    A_orig = (A_active && S_active) ? copy(A.val) : A.val
    U, S_, Vᴴ = svd_compact!(A.val, alg.val)
    copyto!(S.val, diagview(S_))
    if S_active
        if A_active
            if A_is_arg
                ΔS = make_zero(S.dval)
            else
                # The pushforward only writes the leading entries (up to the
                # numerical rank), so clear any stale content first.
                make_zero!(S.dval)
                ΔS = S.dval
            end
            svd_vals_pushforward!(A.dval, A_orig, (U, Diagonal(S.val), Vᴴ), ΔS)
            A_is_arg && (S.dval .= ΔS)
        else
            # Constant input: the singular values have zero tangent.
            make_zero!(S.dval)
        end
    end
    A_active && !A_is_arg && make_zero!(A.dval)
    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        return S
    elseif EnzymeRules.needs_primal(config)
        return S.val
    elseif EnzymeRules.needs_shadow(config)
        return S.dval
    else
        return nothing
    end
end
470522

471523
end

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pul
1111
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
1212
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
1313
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
14+
using MatrixAlgebraKit: svd_pushforward!, svd_trunc_pushforward!, svd_vals_pushforward!
1415
using MatrixAlgebraKit: TruncatedAlgorithm
1516
using LinearAlgebra
1617

@@ -511,7 +512,7 @@ for (f!, f) in (
511512
(:svd_compact!, :svd_compact),
512513
)
513514
@eval begin
514-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
515+
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
515516
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
516517
A, dA = arrayify(A_dA)
517518
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
@@ -535,7 +536,18 @@ for (f!, f) in (
535536
end
536537
return USVᴴ_dUSVᴴ, svd_adjoint
537538
end
538-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
539+
# Forward-mode Mooncake rule for the in-place SVD generated by the surrounding
# `@eval` loop: computes the factorization into the provided `USVᴴ` buffers and
# writes the matching tangents into the tangent buffers of the same `Dual`.
function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual)
    A, dA = arrayify(A_dA)
    USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
    dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
    U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
    S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
    Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
    # The in-place call may use `A` as workspace; the pushforward reads the
    # original input, so snapshot it first.
    A_orig = copy(A)
    $f!(A, USVᴴ, Mooncake.primal(alg_dalg))
    # The incoming tangents describe the old buffer contents. The pushforward
    # only overwrites part of them, so reset everything (this mirrors the
    # `make_zero!` in the Enzyme forward rule).
    fill!(dU, zero(eltype(dU)))
    fill!(dS, zero(eltype(dS)))
    fill!(dVᴴ, zero(eltype(dVᴴ)))
    svd_pushforward!(dA, A_orig, (U, S, Vᴴ), (dU, dS, dVᴴ))
    return USVᴴ_dUSVᴴ
end
550+
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
539551
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
540552
A, dA = arrayify(A_dA)
541553
USVᴴ = $f(A, Mooncake.primal(alg_dalg))
@@ -558,10 +570,23 @@ for (f!, f) in (
558570
end
559571
return USVᴴ_codual, svd_adjoint
560572
end
573+
# Forward-mode Mooncake rule for the out-of-place SVD generated by the
# surrounding `@eval` loop: factorizes `A`, pairs the result with a fresh zero
# tangent, and lets `svd_pushforward!` fill that tangent in place.
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual)
    A, dA = arrayify(A_dA)
    factors = $f(A, Mooncake.primal(alg_dalg))
    result = Dual(factors, Mooncake.zero_tangent(factors))
    tangents = Mooncake.tangent(result)
    # Array views of each factor together with its tangent storage.
    U, dU = arrayify(factors[1], tangents[1])
    S, dS = arrayify(factors[2], tangents[2])
    Vᴴ, dVᴴ = arrayify(factors[3], tangents[3])
    svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
    return result
end
561586
end
562587
end
563588

564-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
589+
@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
565590
function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual)
566591
# compute primal
567592
A, dA = arrayify(A_dA)
@@ -577,8 +602,17 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
577602
end
578603
return S_dS, svd_vals_adjoint
579604
end
605+
# Forward-mode Mooncake rule for `svd_vals!`: stores the singular values of `A`
# in the primal of `S_dS` and their directional derivative in its tangent.
function Mooncake.frule!!(::Dual{typeof(svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual)
    A, dA = arrayify(A_dA)
    S, dS = arrayify(S_dS)
    USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
    # Build the tangent in a zeroed scratch buffer: the pushforward only writes
    # the leading entries (up to the numerical rank), and this keeps `dA` / `A`
    # intact while they are being read even if they share storage with `dS` / `S`.
    ΔS = zero(dS)
    svd_vals_pushforward!(dA, A, USVᴴ, ΔS)
    # Only now overwrite the output buffers.
    copy!(S, diagview(USVᴴ[2]))
    copy!(dS, ΔS)
    return S_dS
end
580614

581-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm}
615+
@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm}
582616
function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual)
583617
# compute primal
584618
A, dA = arrayify(A_dA)
@@ -597,6 +631,16 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
597631
end
598632
return S_codual, svd_vals_adjoint
599633
end
634+
# Forward-mode Mooncake rule for `svd_vals`: returns the singular values of `A`
# (the diagonal of the compact SVD) paired with a tangent that is filled in by
# `svd_vals_pushforward!`.
function Mooncake.frule!!(::Dual{typeof(svd_vals)}, A_dA::Dual, alg_dalg::Dual)
    A, dA = arrayify(A_dA)
    factors = svd_compact(A, Mooncake.primal(alg_dalg))
    vals = diagview(factors[2])
    result = Dual(vals, Mooncake.zero_tangent(vals))
    # Only the tangent array is needed; it shares storage with `result`.
    _, dvals = arrayify(result)
    svd_vals_pushforward!(dA, A, factors, dvals)
    return result
end
600644

601645
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
602646
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ include("pullbacks/svd.jl")
130130
include("pullbacks/polar.jl")
131131

132132
include("pushforwards/polar.jl")
133+
include("pushforwards/svd.jl")
133134

134135
include("precompile.jl")
135136

src/pushforwards/svd.jl

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
    svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, [ind]; rank_atol, kwargs...)

Forward-mode derivative of a singular value decomposition.

Given the input tangent `ΔA` and the factors `USVᴴ = (U, S, Vᴴ)` of `A`, write the
tangents of the factors into `ΔUSVᴴ = (ΔU, ΔS, ΔVᴴ)` and return that tuple. `ΔA` is
only read.

Only the leading `r` singular triplets are differentiated, where `r` counts the singular
values that are at least `rank_atol`. Consequently only the leading `r × r` diagonal of
`ΔS`, the first `r` columns of `ΔU` and the first `r` rows of `ΔVᴴ` are written; all
other entries of the tangent buffers are left untouched. `ΔU` and `ΔVᴴ` are skipped when
`iszerotangent` reports them as zero tangents.

The positional argument `ind` and any additional keyword arguments are accepted but not
used by this method.

Throws a `DimensionMismatch` when `size(ΔA)` differs from `(size(U, 1), size(Vᴴ, 2))`.
"""
function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = default_pullback_rank_atol(A), kwargs...)
    U, Smat, Vᴴ = USVᴴ
    m, n = size(U, 1), size(Vᴴ, 2)
    (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
    minmn = min(m, n)
    S = diagview(Smat)
    ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
    r = searchsortedlast(S, rank_atol; rev = true) # rank

    vΔS = view(ΔS, 1:r, 1:r)

    vU = view(U, :, 1:r)
    vS = view(S, 1:r)
    vSmat = view(Smat, 1:r, 1:r)
    vVᴴ = view(Vᴴ, 1:r, :)

    # compact region
    vV = adjoint(vVᴴ)
    UΔAV = vU' * ΔA * vV
    # Tangent of the singular values: real part of the diagonal of U' ΔA V.
    copyto!(diagview(vΔS), diag(real.(UΔAV)))
    # F[i, j] = 1 / (S[j] - S[i]) and G[i, j] = 1 / (S[j] + S[i]); the diagonal
    # of F is set to zero explicitly since it would otherwise be 1 / 0.
    # NOTE(review): coinciding singular values make off-diagonal entries of F
    # infinite; no degeneracy tolerance is applied here.
    F = one(eltype(S)) ./ (transpose(vS) .- vS)
    G = one(eltype(S)) ./ (transpose(vS) .+ vS)
    diagview(F) .= zero(eltype(F))
    hUΔAV = F .* (UΔAV + UΔAV') ./ 2
    aUΔAV = G .* (UΔAV - UΔAV') ./ 2
    # Generators of the rotations inside the leading subspaces: ∂U ∋ U K̇, ∂V ∋ V Ṁ.
    K̇ = hUΔAV + aUΔAV
    Ṁ = hUΔAV - aUΔAV

    # check gauge condition
    @assert isantihermitian(K̇)
    @assert isantihermitian(Ṁ)
    K̇diag = diagview(K̇)
    for i in eachindex(K̇diag)
        @assert K̇diag[i] ≈ (im / 2) * imag(diagview(UΔAV)[i]) / S[i]
    end

    ∂U = vU * K̇
    ∂V = vV * Ṁ
    # full component
    if size(U, 2) > minmn && size(Vᴴ, 1) > minmn
        # NOTE(review): with `m = size(U, 1)` and `n = size(Vᴴ, 2)`, square
        # (full) factors have size(U, 2) == m and size(Vᴴ, 1) == n, one of which
        # equals `minmn`, so this branch looks unreachable for standard
        # compact/full outputs. It also contains `Uperp * ΔA`, an
        # m × (m - minmn) view times an m × n matrix, which is not conformable
        # in general. Left as-is; confirm the intended formula before relying on it.
        Uperp = view(U, :, (minmn + 1):m)
        Vᴴperp = view(Vᴴ, (minmn + 1):n, :)

        aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp)

        UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2)))
        fill!(UÃÃV, 0)
        view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV
        view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV'
        rhs = vcat(adjoint(Uperp * ΔA * Vᴴ), Vᴴperp * ΔA' * U)
        superKM = -sylvester(UÃÃV, Smat, rhs)
        K̇perp = view(superKM, 1:size(aUAV, 2))
        Ṁperp = view(superKM, (size(aUAV, 2) + 1):(size(aUAV, 1) + size(aUAV, 2)))
        ∂U .+= Uperp * K̇perp
        ∂V .+= Vᴴperp * Ṁperp
    else
        # Projectors onto the orthogonal complements of the leading singular
        # subspaces.
        ImUU = LinearAlgebra.I - vU * vU'
        ImVV = LinearAlgebra.I - vV * vVᴴ
        upper = ImUU * ΔA * vV
        lower = ImVV * ΔA' * vU
        rhs = vcat(upper, lower)

        # Hermitian block matrix [0 Ã; Ã' 0] built from the projected input.
        Ã = ImUU * A * ImVV
        ÃÃ = similar(A, (m + n, m + n))
        fill!(ÃÃ, 0)
        view(ÃÃ, (1:m), m .+ (1:n)) .= Ã
        view(ÃÃ, m .+ (1:n), 1:m) .= Ã'

        # Solve the coupled Sylvester equation for the components of ∂U and ∂V
        # outside the leading subspaces; rows 1:m belong to U, the rest to V.
        superLN = -sylvester(ÃÃ, vSmat, rhs)
        ∂U += view(superLN, 1:size(upper, 1), :)
        ∂V += view(superLN, (size(upper, 1) + 1):(size(upper, 1) + size(lower, 1)), :)
    end
    if !iszerotangent(ΔU)
        vΔU = view(ΔU, :, 1:r)
        copyto!(vΔU, ∂U)
    end
    if !iszerotangent(ΔVᴴ)
        vΔVᴴ = view(ΔVᴴ, 1:r, :)
        adjoint!(vΔVᴴ, ∂V)
    end
    return (ΔU, ΔS, ΔVᴴ)
end
83+
84+
"""
    svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, kwargs...)

Placeholder for the forward-mode derivative of a truncated SVD.

Not implemented yet: calling it throws an `ErrorException` instead of silently
returning `nothing`, so that a missing rule cannot produce wrong derivatives unnoticed.
"""
function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol = default_pullback_rank_atol(A), kwargs...)
    # TODO: implement the pushforward for the truncated decomposition.
    return error("svd_trunc_pushforward! is not implemented yet")
end
87+
88+
# Pushforward for the singular values only: wraps `ΔS` with `diagonal`, marks
# the tangents of `U` and `Vᴴ` as absent (`nothing`), and delegates to
# `svd_pushforward!`, whose return value is passed through unchanged.
function svd_vals_pushforward!(
        ΔA, A, USVᴴ, ΔS, ind = Colon();
        rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
        degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2])
    )
    tangents = (nothing, diagonal(ΔS), nothing)
    return svd_pushforward!(
        ΔA, A, USVᴴ, tangents, ind;
        rank_atol = rank_atol, degeneracy_atol = degeneracy_atol
    )
end

test/enzyme/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
1616
if !is_buildkite
1717
TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
1818
AT = Diagonal{T, Vector{T}}
19-
m == n && TestSuite.test_enzyme_svd(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
19+
m == n && TestSuite.test_enzyme_svd(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
2020
end
2121
end

test/testsuite/enzyme/svd.jl

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,90 @@
11
# Entry point for the Enzyme SVD test suite: runs every SVD flavour for the
# given element/matrix type `T` and size `sz`, forwarding all keyword
# arguments (tolerances etc.) to the individual test functions.
function test_enzyme_svd(T::Type, sz; kwargs...)
    summary_str = testargs_summary(T, sz)
    return @testset "Enzyme svd $summary_str" begin
        # All groups stay enabled: commenting any of them out silently drops
        # coverage for the corresponding rules.
        test_enzyme_svd_compact(T, sz; kwargs...)
        test_enzyme_svd_full(T, sz; kwargs...)
        test_enzyme_svd_vals(T, sz; kwargs...)
        test_enzyme_svd_trunc(T, sz; kwargs...)
    end
end
1010

11+
"""
12+
test_enzyme_svd_compact(T, sz; rng, atol, rtol, fdm)
13+
14+
Test the Enzyme forward- and reverse-mode AD rule for `svd_compact` and its in-place variant.
15+
"""
1116
function test_enzyme_svd_compact(
        T, sz;
        rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
        fdm = enzyme_fdm(T)
    )
    return @testset "svd_compact: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
        # Reverse mode: out-of-place and in-place variants share input and
        # output tangent.
        A_rev = instantiate_matrix(T, sz)
        alg = MatrixAlgebraKit.select_algorithm(svd_compact, A_rev)
        _, ΔUSVᴴ = ad_svd_compact_setup(A_rev)
        test_reverse(svd_compact, RT, (A_rev, TA), (alg, Const); output_tangent = ΔUSVᴴ, atol, rtol, fdm)
        test_reverse(call_and_zero!, RT, (svd_compact!, Const), (A_rev, TA), (alg, Const); output_tangent = ΔUSVᴴ, atol, rtol, fdm)
        # Forward mode is only exercised for real element types, on a freshly
        # instantiated matrix.
        if eltype(T) <: Real
            A_fwd = instantiate_matrix(T, sz)
            test_forward(svd_compact, RT, (A_fwd, TA), (alg, Const); atol, rtol, fdm)
            test_forward(call_and_zero!, RT, (svd_compact!, Const), (A_fwd, TA), (alg, Const); atol, rtol, fdm)
        end
    end
end
2434

35+
"""
36+
test_enzyme_svd_full(T, sz; rng, atol, rtol, fdm)
37+
38+
Test the Enzyme forward- and reverse-mode AD rule for `svd_full` and its in-place variant. The
39+
gauge-dependent extra columns of `U` and rows of `Vᴴ` are zeroed out in the cotangent.
40+
"""
2541
function test_enzyme_svd_full(
        T, sz;
        rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
        fdm = enzyme_fdm(T)
    )
    return @testset "svd_full: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
        # Reverse mode: out-of-place and in-place variants share input and
        # output tangent.
        A_rev = instantiate_matrix(T, sz)
        alg = MatrixAlgebraKit.select_algorithm(svd_full, A_rev)
        _, ΔUSVᴴ = ad_svd_full_setup(A_rev)
        test_reverse(svd_full, RT, (A_rev, TA), (alg, Const); output_tangent = ΔUSVᴴ, atol, rtol, fdm)
        test_reverse(call_and_zero!, RT, (svd_full!, Const), (A_rev, TA), (alg, Const); output_tangent = ΔUSVᴴ, atol, rtol, fdm)
        # Forward mode is only exercised for real element types, on a freshly
        # instantiated matrix.
        if eltype(T) <: Real
            A_fwd = instantiate_matrix(T, sz)
            test_forward(svd_full, RT, (A_fwd, TA), (alg, Const); atol, rtol, fdm)
            test_forward(call_and_zero!, RT, (svd_full!, Const), (A_fwd, TA), (alg, Const); atol, rtol, fdm)
        end
    end
end
3859

60+
"""
61+
test_enzyme_svd_vals(T, sz; rng, atol, rtol, fdm)
62+
63+
Test the Enzyme forward- and reverse-mode AD rule for `svd_vals` and its in-place variant.
64+
"""
3965
function test_enzyme_svd_vals(
        T, sz;
        rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
        fdm = enzyme_fdm(T)
    )
    return @testset "svd_vals: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
        # Reverse mode: out-of-place and in-place variants share input and
        # output tangent.
        A_rev = instantiate_matrix(T, sz)
        alg = MatrixAlgebraKit.select_algorithm(svd_vals, A_rev)
        _, ΔS = ad_svd_vals_setup(A_rev)
        test_reverse(svd_vals, RT, (A_rev, TA), (alg, Const); output_tangent = ΔS, atol, rtol, fdm)
        test_reverse(call_and_zero!, RT, (svd_vals!, Const), (A_rev, TA), (alg, Const); output_tangent = ΔS, atol, rtol, fdm)
        # Forward mode runs for every element type, on a freshly instantiated
        # matrix.
        A_fwd = instantiate_matrix(T, sz)
        test_forward(svd_vals, RT, (A_fwd, TA), (alg, Const); atol, rtol, fdm)
        test_forward(call_and_zero!, RT, (svd_vals!, Const), (A_fwd, TA), (alg, Const); atol, rtol, fdm)
    end
end
5281

82+
"""
83+
test_enzyme_svd_trunc(T, sz; rng, atol, rtol)
84+
85+
Test the Enzyme reverse-mode AD rules for `svd_trunc`, `svd_trunc_no_error`, and their
86+
in-place variants, over a range of truncation ranks and a tolerance-based truncation.
87+
"""
5388
function test_enzyme_svd_trunc(
5489
T, sz;
5590
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),

0 commit comments

Comments
 (0)