Skip to content

Commit 584e948

Browse files
authored
Forward AD rules and tests for polar decompositions (#242)
* Add Mooncake fwd rules for polar * Forward rules and tests for Enzyme + polar * Use _sylvester fallback * Comments and fix * Apply Jutho suggestions * Pushforward improvements * even more optimization
1 parent 33d77fd commit 584e948

6 files changed

Lines changed: 116 additions & 16 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
88
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback!
99
using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback!
1010
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
11+
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
1112
using Enzyme
1213
using Enzyme.EnzymeCore
1314
using Enzyme.EnzymeCore: EnzymeRules
@@ -117,6 +118,36 @@ for (f, pb) in (
117118
end
118119
end
119120

121+
for (f, pf) in (
122+
(left_polar!, left_polar_pushforward!),
123+
(right_polar!, right_polar_pushforward!),
124+
)
125+
@eval begin
126+
function EnzymeRules.forward(
127+
config::EnzymeRules.FwdConfigWidth{1},
128+
func::Const{typeof($f)},
129+
::Type{RT},
130+
A::Annotation,
131+
arg::Annotation{TA},
132+
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
133+
) where {RT, TA}
134+
$f(A.val, arg.val, alg.val)
135+
if !isa(A, Const) && !isa(arg, Const)
136+
$pf(A.dval, A.val, arg.val, arg.dval)
137+
end
138+
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
139+
return arg
140+
elseif EnzymeRules.needs_primal(config)
141+
return arg.val
142+
elseif EnzymeRules.needs_shadow(config)
143+
return arg.dval
144+
else
145+
return nothing
146+
end
147+
end
148+
end
149+
end
150+
120151
for (f, pb) in (
121152
(qr_null!, qr_null_pullback!),
122153
(lq_null!, lq_null_pullback!),

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@ using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
99
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
1010
using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
1111
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
12+
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
1213
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
1314
using MatrixAlgebraKit: TruncatedAlgorithm
1415
using LinearAlgebra
1516

1617
Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent
1718

1819
# needed for GPU tests because Mooncake can't differentiate through CUDA kernels
19-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(zero!), AbstractArray}
20+
@is_primitive Mooncake.DefaultCtx Tuple{typeof(zero!), AbstractArray}
2021
function Mooncake.rrule!!(::CoDual{typeof(zero!)}, A_dA::CoDual)
2122
A, dA = arrayify(A_dA)
2223
Ac = copy(A)
@@ -28,6 +29,12 @@ function Mooncake.rrule!!(::CoDual{typeof(zero!)}, A_dA::CoDual)
2829
end
2930
return A_dA, zero_adjoint
3031
end
32+
function Mooncake.frule!!(::Dual{typeof(zero!)}, A_dA::Dual)
33+
A, dA = arrayify(A_dA)
34+
zero!(A)
35+
zero!(dA)
36+
return A_dA
37+
end
3138

3239
# two-argument in-place factorizations like LQ, QR, EIG
3340
for (f!, f, pb, adj) in (
@@ -40,7 +47,6 @@ for (f!, f, pb, adj) in (
4047
(:left_polar!, :left_polar, :left_polar_pullback!, :left_polar_adjoint),
4148
(:right_polar!, :right_polar, :right_polar_pullback!, :right_polar_adjoint),
4249
)
43-
4450
@eval begin
4551
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
4652
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
@@ -104,6 +110,36 @@ for (f!, f, pb, adj) in (
104110
end
105111
end
106112

113+
for (f!, f, pf) in (
114+
(:left_polar!, :left_polar, :left_polar_pushforward!),
115+
(:right_polar!, :right_polar, :right_polar_pushforward!),
116+
)
117+
@eval begin
118+
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
119+
function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm})
120+
A, dA = arrayify(A_dA)
121+
args = Mooncake.primal(args_dargs)
122+
dargs = Mooncake.tangent(args_dargs)
123+
arg1, darg1 = arrayify(args[1], dargs[1])
124+
arg2, darg2 = arrayify(args[2], dargs[2])
125+
$f!(A, args, Mooncake.primal(alg_dalg))
126+
$pf(dA, A, (arg1, arg2), (darg1, darg2))
127+
return args_dargs
128+
end
129+
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
130+
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm})
131+
A, dA = arrayify(A_dA)
132+
output = $f(A, Mooncake.primal(alg_dalg))
133+
doutput = Mooncake.zero_tangent(output)
134+
output_dual = Dual(output, doutput)
135+
arg1, darg1 = arrayify(output[1], doutput[1])
136+
arg2, darg2 = arrayify(output[2], doutput[2])
137+
$pf(dA, A, (arg1, arg2), (darg1, darg2))
138+
return output_dual
139+
end
140+
end
141+
end
142+
107143
for (f!, f, pb, adj) in (
108144
(:qr_null!, :qr_null, :qr_null_pullback!, :qr_null_adjoint),
109145
(:lq_null!, :lq_null, :lq_null_pullback!, :lq_null_adjoint),

src/MatrixAlgebraKit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ include("pullbacks/eigh.jl")
129129
include("pullbacks/svd.jl")
130130
include("pullbacks/polar.jl")
131131

132+
include("pushforwards/polar.jl")
133+
132134
include("precompile.jl")
133135

134136
end

src/pushforwards/polar.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...)
2+
W, P = WP
3+
ΔW, ΔP = ΔWP
4+
mul!(ΔP, adjoint(W), ΔA, +1, 0)
5+
= _sylvester(P, P, adjoint(ΔP) - ΔP)
6+
mul!(ΔW, ΔA, inv(P), +1, 0)
7+
WᴴdAiP = W' * ΔW
8+
mul!(ΔW, W, WᴴdAiP, -1, +1)
9+
ΔW = mul!(ΔW, W, K̇, +1, +1)
10+
ΔP = mul!(ΔP, K̇, P, -1, +1)
11+
return (ΔW, ΔP)
12+
end
13+
14+
function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)
15+
P, Wᴴ = PWᴴ
16+
ΔP, ΔWᴴ = ΔPWᴴ
17+
mul!(ΔP, ΔA, adjoint(Wᴴ), +1, 0)
18+
= _sylvester(P, P, adjoint(ΔP) - ΔP)
19+
mul!(ΔWᴴ, inv(P), ΔA, +1, 0)
20+
iPdAW = ΔWᴴ * Wᴴ'
21+
mul!(ΔWᴴ, iPdAW, Wᴴ, -1, +1)
22+
ΔWᴴ = mul!(ΔWᴴ, K̇, Wᴴ, +1, +1)
23+
ΔP = mul!(ΔP, P, K̇, -1, +1)
24+
return (ΔWᴴ, ΔP)
25+
end

test/testsuite/enzyme/polar.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,43 +14,49 @@ end
1414
"""
1515
test_enzyme_left_polar(T, sz; rng, atol, rtol)
1616
17-
Test the Enzyme reverse-mode AD rule for `left_polar` and its in-place variant. Only runs
18-
for tall or square matrices (`m >= n`).
17+
Test the Enzyme forward- and reverse-mode AD rule for `left_polar` and its in-place variant.
18+
Only runs for tall or square matrices (`m >= n`).
1919
"""
2020
function test_enzyme_left_polar(
2121
T, sz;
2222
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T)
2323
)
24-
return @testset "left_polar reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
24+
return @testset "left_polar: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
2525
A = instantiate_matrix(T, sz)
2626
m, n = size(A)
2727
if m >= n
2828
alg = MatrixAlgebraKit.select_algorithm(left_polar, A)
2929
WP, ΔWP = ad_left_polar_setup(A)
3030
test_reverse(left_polar, RT, (A, TA), (alg, Const); atol, rtol)
3131
test_reverse(call_and_zero!, RT, (left_polar!, Const), (A, TA), (alg, Const); atol, rtol)
32+
A = instantiate_matrix(T, sz)
33+
test_forward(left_polar, RT, (A, TA), (alg, Const); atol, rtol)
34+
test_forward(call_and_zero!, RT, (left_polar!, Const), (A, TA), (alg, Const); atol, rtol)
3235
end
3336
end
3437
end
3538

3639
"""
3740
test_enzyme_right_polar(T, sz; rng, atol, rtol)
3841
39-
Test the Enzyme reverse-mode AD rule for `right_polar` and its in-place variant. Only runs
40-
for wide or square matrices (`m <= n`).
42+
Test the Enzyme forward- and reverse-mode AD rule for `right_polar` and its in-place variant.
43+
Only runs for wide or square matrices (`m <= n`).
4144
"""
4245
function test_enzyme_right_polar(
4346
T, sz;
4447
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T)
4548
)
46-
return @testset "right_polar reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
49+
return @testset "right_polar: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
4750
A = instantiate_matrix(T, sz)
4851
m, n = size(A)
4952
if m <= n
5053
alg = MatrixAlgebraKit.select_algorithm(right_polar, A)
5154
PWᴴ, ΔPWᴴ = ad_right_polar_setup(A)
5255
test_reverse(right_polar, RT, (A, TA), (alg, Const); atol, rtol)
5356
test_reverse(call_and_zero!, RT, (right_polar!, Const), (A, TA), (alg, Const); atol, rtol)
57+
A = instantiate_matrix(T, sz)
58+
test_forward(right_polar, RT, (A, TA), (alg, Const); atol, rtol)
59+
test_forward(call_and_zero!, RT, (right_polar!, Const), (A, TA), (alg, Const); atol, rtol)
5460
end
5561
end
5662
end

test/testsuite/mooncake/polar.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ end
1414
"""
1515
test_mooncake_left_polar(T, sz; rng, atol, rtol)
1616
17-
Test the Mooncake reverse-mode AD rule for `left_polar` and its in-place variant. Only runs
18-
for tall or square matrices (`m >= n`).
17+
Test the Mooncake forward- and reverse-mode AD rule for `left_polar` and its in-place variant.
18+
Only runs for tall or square matrices (`m >= n`).
1919
"""
2020
function test_mooncake_left_polar(
2121
T, sz;
@@ -31,20 +31,20 @@ function test_mooncake_left_polar(
3131

3232
Mooncake.TestUtils.test_rule(
3333
rng, left_polar, A, alg;
34-
mode = Mooncake.ReverseMode, output_tangent, atol, rtol
34+
output_tangent, atol, rtol
3535
)
3636
Mooncake.TestUtils.test_rule(
3737
rng, call_and_zero!, left_polar!, A, alg;
38-
mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false
38+
output_tangent, atol, rtol, is_primitive = false
3939
)
4040
end
4141
end
4242

4343
"""
4444
test_mooncake_right_polar(T, sz; rng, atol, rtol)
4545
46-
Test the Mooncake reverse-mode AD rule for `right_polar` and its in-place variant. Only runs
47-
for wide or square matrices (`m <= n`).
46+
Test the Mooncake forward- and reverse-mode AD rule for `right_polar` and its in-place variant.
47+
Only runs for wide or square matrices (`m <= n`).
4848
"""
4949
function test_mooncake_right_polar(
5050
T, sz;
@@ -60,11 +60,11 @@ function test_mooncake_right_polar(
6060

6161
Mooncake.TestUtils.test_rule(
6262
rng, right_polar, A, alg;
63-
mode = Mooncake.ReverseMode, output_tangent, atol, rtol
63+
output_tangent, atol, rtol
6464
)
6565
Mooncake.TestUtils.test_rule(
6666
rng, call_and_zero!, right_polar!, A, alg;
67-
mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false
67+
output_tangent, atol, rtol, is_primitive = false
6868
)
6969
end
7070
end

0 commit comments

Comments
 (0)