Skip to content

Commit 5474c5e

Browse files
lkdvos and claude committed
Add AD rules for projection methods
Add rrules/pullbacks for `project_hermitian!`, `project_antihermitian!`, and `project_isometric!` directly in each AD backend extension (ChainRulesCore, Enzyme, Mooncake). The hermitian/antihermitian pullbacks are self-adjoint, while the isometric pullback delegates to `left_polar_pullback!` with zero ΔP. Includes test utilities and tests for all three backends. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ab8aea1 commit 5474c5e

7 files changed

Lines changed: 362 additions & 0 deletions

File tree

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,4 +274,46 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A, PWᴴ, alg)
274274
return PWᴴ, right_polar_pullback
275275
end
276276

277+
# Reverse rule for `project_hermitian!(A, Aₕ, alg)`.
#
# The primal is evaluated on the result of `copy_input(project_hermitian, A)`
# instead of on `A` itself. The pullback maps an output cotangent to the
# cotangent of `A` by applying `project_hermitian` to it; the output buffer and
# the algorithm receive no cotangent.
function ChainRulesCore.rrule(::typeof(project_hermitian!), A, Aₕ, alg)
    result = project_hermitian!(copy_input(project_hermitian, A), Aₕ, alg)

    # A structurally zero cotangent short-circuits without touching any array.
    project_hermitian_pullback(::ZeroTangent) =
        (NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent())

    function project_hermitian_pullback(Δresult)
        return (
            NoTangent(),
            project_hermitian(unthunk(Δresult)),
            ZeroTangent(),
            NoTangent(),
        )
    end

    return result, project_hermitian_pullback
end
289+
290+
# Reverse rule for `project_antihermitian!(A, Aₐ, alg)`.
#
# The primal is evaluated on the result of `copy_input(project_antihermitian, A)`
# instead of on `A` itself. The pullback maps an output cotangent to the
# cotangent of `A` by applying `project_antihermitian` to it; the output buffer
# and the algorithm receive no cotangent.
function ChainRulesCore.rrule(::typeof(project_antihermitian!), A, Aₐ, alg)
    result = project_antihermitian!(copy_input(project_antihermitian, A), Aₐ, alg)

    # A structurally zero cotangent short-circuits without touching any array.
    project_antihermitian_pullback(::ZeroTangent) =
        (NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent())

    function project_antihermitian_pullback(Δresult)
        return (
            NoTangent(),
            project_antihermitian(unthunk(Δresult)),
            ZeroTangent(),
            NoTangent(),
        )
    end

    return result, project_antihermitian_pullback
end
302+
303+
# Reverse rule for `project_isometric!(A, W, alg)`.
#
# Instead of calling `project_isometric!` directly, the rule runs `left_polar!`
# into freshly allocated buffers so that both polar factors are kept for the
# pullback, then copies the first factor into the caller's `W`. The pullback
# delegates to `left_polar_pullback!` with `nothing` as the cotangent of the
# second factor.
function ChainRulesCore.rrule(::typeof(project_isometric!), A, W, alg)
    input = copy_input(project_isometric, A)
    ncols = size(W, 2)
    factors = left_polar!(input, (similar(W), similar(W, ncols, ncols)), alg)
    isometry = copy!(W, factors[1])

    # A structurally zero cotangent short-circuits without touching any array.
    project_isometric_pullback(::ZeroTangent) =
        (NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent())

    function project_isometric_pullback(ΔW)
        grad = zero(A)
        MatrixAlgebraKit.left_polar_pullback!(grad, A, factors, (unthunk(ΔW), nothing))
        return NoTangent(), grad, ZeroTangent(), NoTangent()
    end

    return isometry, project_isometric_pullback
end
318+
277319
end

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,4 +454,91 @@ function EnzymeRules.reverse(
454454
return (nothing, nothing, nothing)
455455
end
456456

457+
# Single-output projections: `project_hermitian!` and `project_antihermitian!`.
#
# Both reverse-mode rules are generated from one template. Each pair is
# (in-place primal, out-of-place projection applied to the output cotangent in
# the reverse pass).
for (f!, project_f) in (
        (project_hermitian!, project_hermitian),
        (project_antihermitian!, project_antihermitian),
    )
    @eval begin
        # Forward pass: run the primal and record the shadow of the output.
        #
        # The tape is `(cache_arg, dret)`:
        # - `cache_arg` is a copy of the result when the returned array is not
        #   `arg.val` itself or when `arg` is flagged as overwritten, else `nothing`;
        # - `dret` is the output shadow when one is requested, else `nothing`.
        function EnzymeRules.augmented_primal(
                config::EnzymeRules.RevConfigWidth{1},
                func::Const{typeof($f!)},
                ::Type{RT},
                A::Annotation,
                arg::Annotation{TA},
                alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
            ) where {RT, TA}
            ret = func.val(A.val, arg.val, alg.val)
            needs_copy = (arg.val !== ret) || EnzymeRules.overwritten(config)[3]
            cache_arg = needs_copy ? copy(ret) : nothing
            dret = if EnzymeRules.needs_shadow(config)
                # Without a usable shadow on `arg`, allocate a fresh zero shadow.
                (TA == Nothing || isa(arg, Const)) ? zero(ret) : arg.dval
            else
                nothing
            end
            primal = EnzymeRules.needs_primal(config) ? ret : nothing
            return EnzymeRules.AugmentedReturn(primal, dret, (cache_arg, dret))
        end

        # Reverse pass: accumulate the projected output cotangent into `A.dval`
        # and reset the shadow of the output argument.
        function EnzymeRules.reverse(
                config::EnzymeRules.RevConfigWidth{1},
                func::Const{typeof($f!)},
                ::Type{RT},
                cache,
                A::Annotation,
                arg::Annotation,
                alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
            ) where {RT}
            _, darg = cache
            # Select the cotangent lazily: `something(darg, arg.dval)` would
            # evaluate `arg.dval` unconditionally, which throws for a `Const`
            # argument (it has no `dval` field) even when `darg` is available.
            Δout = if !isnothing(darg)
                darg
            elseif !isa(arg, Const)
                arg.dval
            else
                nothing
            end
            if !isa(A, Const) && !isnothing(Δout)
                A.dval .+= $project_f(Δout)
            end
            !isa(arg, Const) && make_zero!(arg.dval)
            return (nothing, nothing, nothing)
        end
    end
end
501+
502+
# `project_isometric!` needs special handling: the pullback is expressed through
# `left_polar_pullback!`, which requires both polar factors, so the forward pass
# computes the full polar decomposition instead of only the isometry.

# Forward pass: run `left_polar!` on a copy of `A.val`, writing the isometry
# into `W.val`.
#
# The tape is `(cache_W, P, dret)`:
# - `cache_W` is a copy of the isometry when the returned array is not `W.val`
#   itself or when `W` is flagged as overwritten, else `nothing` (the reverse
#   pass then reads `W.val`);
# - `P` is the second polar factor. It is allocated by this rule and is always
#   stored, because the reverse pass cannot recover it from anywhere else;
# - `dret` is the output shadow when one is requested, else `nothing`.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(project_isometric!)},
        ::Type{RT},
        A::Annotation,
        W::Annotation{TW},
        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
    ) where {RT, TW}
    Ac = copy(A.val)
    _, n = size(A.val)
    P = similar(A.val, n, n)
    WP = left_polar!(Ac, (W.val, P), alg.val)
    needs_copy = (WP[1] !== W.val) || EnzymeRules.overwritten(config)[3]
    cache_W = needs_copy ? copy(WP[1]) : nothing
    dret = if EnzymeRules.needs_shadow(config)
        # Without a usable shadow on `W`, allocate a fresh zero shadow.
        (TW == Nothing || isa(W, Const)) ? zero(WP[1]) : W.dval
    else
        nothing
    end
    primal = EnzymeRules.needs_primal(config) ? WP[1] : nothing
    return EnzymeRules.AugmentedReturn(primal, dret, (cache_W, WP[2], dret))
end

# Reverse pass: propagate the cotangent of the isometry into `A.dval` through
# `left_polar_pullback!` (with `nothing` as the cotangent of `P`) and reset the
# shadow of `W`.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(project_isometric!)},
        ::Type{RT},
        cache,
        A::Annotation,
        W::Annotation,
        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
    ) where {RT}
    cache_W, P, dW = cache
    if !isa(A, Const)
        Wval = something(cache_W, W.val)
        # Fall back to the shadow stored on `W` when no output shadow was
        # requested in the forward pass; a `Const` `W` carries no cotangent.
        ΔW = if !isnothing(dW)
            dW
        elseif !isa(W, Const)
            W.dval
        else
            nothing
        end
        if !isnothing(ΔW)
            # The primal input is passed as `nothing`, as this rule always did.
            # NOTE(review): presumably `left_polar_pullback!` never reads it —
            # confirm against its definition.
            left_polar_pullback!(A.dval, nothing, (Wval, P), (ΔW, nothing))
        end
    end
    !isa(W, Const) && make_zero!(W.dval)
    return (nothing, nothing, nothing)
end
543+
457544
end

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,4 +779,82 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
779779
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
780780
end
781781

782+
# Single-output projections: `project_hermitian!` / `project_antihermitian!`
# and their out-of-place counterparts.
#
# Each entry is (in-place primal, out-of-place primal, name of the generated
# adjoint closure). In both rules the adjoint adds the projection of the output
# cotangent to the cotangent of `A` and then zeroes the output cotangent.
for (f!, f, adj) in (
        (:project_hermitian!, :project_hermitian, :project_hermitian_adjoint),
        (:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint),
    )
    @eval begin
        # In-place variant: snapshots of both arguments are taken before the
        # primal runs and are restored in the adjoint.
        @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
        function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
            A, dA = arrayify(A_dA)
            A_saved = copy(A)
            out, dout = arrayify(arg_darg)
            out_saved = copy(out)
            $f!(A, out, Mooncake.primal(alg_dalg))
            function $adj(::NoRData)
                # Statement order is kept deliberately: restore `A`, accumulate,
                # restore the output buffer, then clear its cotangent.
                copy!(A, A_saved)
                dA .+= $f(dout)
                copy!(out, out_saved)
                zero!(dout)
                return NoRData(), NoRData(), NoRData(), NoRData()
            end
            return arg_darg, $adj
        end

        # Out-of-place variant: the output is freshly created, so nothing has to
        # be restored in the adjoint.
        @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
        function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
            A, dA = arrayify(A_dA)
            result = $f(A, Mooncake.primal(alg_dalg))
            result_codual = CoDual(result, Mooncake.zero_tangent(result))
            function $adj(::NoRData)
                _, dresult = arrayify(result_codual)
                dA .+= $f(dresult)
                zero!(dresult)
                return NoRData(), NoRData(), NoRData()
            end
            return result_codual, $adj
        end
    end
end
820+
821+
# `project_isometric!` needs special handling: the adjoint is expressed through
# `left_polar_pullback!`, so the forward pass computes the full polar
# decomposition on copies and only then writes the isometry into `W`.
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(project_isometric!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(f_df::CoDual{typeof(project_isometric!)}, A_dA::CoDual, W_dW::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
    A, dA = arrayify(A_dA)
    W, dW = arrayify(W_dW)
    # Snapshots of both arguments, restored in the adjoint.
    A_saved = copy(A)
    W_saved = copy(W)
    _, n = size(A)
    polar = left_polar!(copy(A), (copy(W), similar(A, n, n)), Mooncake.primal(alg_dalg))
    copy!(W, polar[1])
    function project_isometric_adjoint(::NoRData)
        # Statement order is kept deliberately: restore `A`, accumulate into
        # `dA`, restore `W`, then clear its cotangent.
        copy!(A, A_saved)
        left_polar_pullback!(dA, A, polar, (dW, nothing))
        copy!(W, W_saved)
        zero!(dW)
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return W_dW, project_isometric_adjoint
end
842+
843+
# Out-of-place `project_isometric`: the forward pass calls `left_polar` so that
# both polar factors are available to `left_polar_pullback!` in the adjoint; the
# first factor is returned as the primal output.
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(project_isometric), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(f_df::CoDual{typeof(project_isometric)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
    A, dA = arrayify(A_dA)
    polar = left_polar(A, Mooncake.primal(alg_dalg))
    isometry = polar[1]
    result_codual = CoDual(isometry, Mooncake.zero_tangent(isometry))
    function project_isometric_adjoint(::NoRData)
        _, dresult = arrayify(result_codual)
        # `nothing` stands for the cotangent of the second polar factor.
        left_polar_pullback!(dA, A, polar, (dresult, nothing))
        zero!(dresult)
        return NoRData(), NoRData(), NoRData()
    end
    return result_codual, project_isometric_adjoint
end
859+
782860
end

test/testsuite/ad_utils.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,27 @@ function ad_right_null_setup(A)
421421
ΔNᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2]
422422
return Nᴴ, ΔNᴴ
423423
end
424+
425+
# Build the test fixture for `project_hermitian`: returns the projection of `A`
# together with a random output cotangent that has the element type and
# dimensions of `A`.
function ad_project_hermitian_setup(A)
    nrows, ncols = size(A)
    projected = project_hermitian(A)
    cotangent = similar(A, eltype(A), nrows, ncols)
    randn!(cotangent)
    return projected, cotangent
end
432+
433+
# Build the test fixture for `project_antihermitian`: returns the projection of
# `A` together with a random output cotangent that has the element type and
# dimensions of `A`.
function ad_project_antihermitian_setup(A)
    nrows, ncols = size(A)
    projected = project_antihermitian(A)
    cotangent = similar(A, eltype(A), nrows, ncols)
    randn!(cotangent)
    return projected, cotangent
end
440+
441+
# Build the test fixture for `project_isometric`: returns the projection of `A`
# together with a random output cotangent that has the element type and
# dimensions of `A`.
function ad_project_isometric_setup(A)
    nrows, ncols = size(A)
    projected = project_isometric(A)
    cotangent = similar(A, eltype(A), nrows, ncols)
    randn!(cotangent)
    return projected, cotangent
end

test/testsuite/chainrules.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ for f in
1010
:eig_trunc_no_error, :eigh_trunc_no_error,
1111
:svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals,
1212
:left_polar, :right_polar,
13+
:project_hermitian, :project_antihermitian, :project_isometric,
1314
)
1415
copy_f = Symbol(:cr_copy_, f)
1516
f! = Symbol(f, '!')
@@ -46,6 +47,7 @@ function test_chainrules(T::Type, sz; kwargs...)
4647
test_chainrules_svd(T, sz; kwargs...)
4748
test_chainrules_polar(T, sz; kwargs...)
4849
test_chainrules_orthnull(T, sz; kwargs...)
50+
test_chainrules_projections(T, sz; kwargs...)
4951
end
5052
end
5153

@@ -610,3 +612,58 @@ function test_chainrules_orthnull(
610612
)
611613
end
612614
end
615+
616+
# Test the ChainRules reverse rules of the projection methods.
#
# `T` and `sz` are forwarded to `instantiate_matrix` to build the test matrix;
# `atol` / `rtol` are the tolerances handed to `test_rrule`. The hermitian and
# antihermitian projections are exercised for square inputs only, the isometric
# projection for inputs with more rows than columns. Every rule is checked
# twice: directly through the `cr_copy_*` wrapper, and through Zygote via
# `rrule_via_ad`.
function test_chainrules_projections(
        T::Type, sz;
        atol::Real = 0, rtol::Real = precision(T),
        kwargs...
    )
    summary_str = testargs_summary(T, sz)
    return @testset "Projections Chainrules AD rules $summary_str" begin
        A = instantiate_matrix(T, sz)
        m, n = size(A)
        config = Zygote.ZygoteRuleConfig()
        if m == n
            alg_h = MatrixAlgebraKit.default_hermitian_algorithm(A)
            @testset "project_hermitian" begin
                _, ΔAₕ = ad_project_hermitian_setup(A)
                # `alg ⊢ NoTangent()` pins the tangent of the algorithm argument.
                test_rrule(
                    cr_copy_project_hermitian, A, alg_h ⊢ NoTangent();
                    output_tangent = ΔAₕ, atol = atol, rtol = rtol
                )
                test_rrule(
                    config, project_hermitian, A;
                    output_tangent = ΔAₕ,
                    atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
                )
            end
            @testset "project_antihermitian" begin
                _, ΔAₐ = ad_project_antihermitian_setup(A)
                test_rrule(
                    cr_copy_project_antihermitian, A, alg_h ⊢ NoTangent();
                    output_tangent = ΔAₐ, atol = atol, rtol = rtol
                )
                test_rrule(
                    config, project_antihermitian, A;
                    output_tangent = ΔAₐ,
                    atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
                )
            end
        end
        if m > n
            @testset "project_isometric" begin
                _, ΔW = ad_project_isometric_setup(A)
                alg_iso = MatrixAlgebraKit.default_polar_algorithm(A)
                test_rrule(
                    cr_copy_project_isometric, A, alg_iso ⊢ NoTangent();
                    output_tangent = ΔW, atol = atol, rtol = rtol
                )
                test_rrule(
                    config, project_isometric, A;
                    output_tangent = ΔW,
                    atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
                )
            end
        end
    end
end

test/testsuite/enzyme.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ function test_enzyme(T::Type, sz; kwargs...)
105105
test_enzyme_polar(T, sz; kwargs...)
106106
test_enzyme_orthnull(T, sz; kwargs...)
107107
end
108+
test_enzyme_projections(T, sz; kwargs...)
108109
end
109110
end
110111

@@ -462,3 +463,41 @@ function test_enzyme_orthnull(
462463
end
463464
end
464465
end
466+
467+
# Test the Enzyme reverse rules of the projection methods.
#
# `T` and `sz` are forwarded to `instantiate_matrix` to build the test matrix;
# `atol` / `rtol` are the tolerances handed to `test_reverse`. The hermitian and
# antihermitian projections are exercised for square inputs only, the isometric
# projection for inputs with more rows than columns. The finite-difference
# comparison runs only for `BlasFloat` element types, the pullback comparison
# only when `is_cpu(A)` holds.
#
# NOTE(review): `rng` is not defined in this function — presumably a global of
# the test suite; confirm.
function test_enzyme_projections(
        T::Type, sz;
        atol::Real = 0, rtol::Real = precision(T),
        kwargs...
    )
    summary_str = testargs_summary(T, sz)
    return @testset "Projections Enzyme AD rules $summary_str" begin
        A = instantiate_matrix(T, sz)
        m, n = size(A)
        # Single-precision element types get a bounded step range.
        fdm = if eltype(T) <: Union{Float32, ComplexF32}
            EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2)
        else
            EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
        end
        if m == n
            @testset "project_hermitian" begin
                @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
                    Aₕ, ΔAₕ = ad_project_hermitian_setup(A)
                    if eltype(T) <: BlasFloat
                        test_reverse(
                            project_hermitian, RT, (A, TA);
                            atol, rtol, output_tangent = ΔAₕ, fdm
                        )
                    end
                    if is_cpu(A)
                        enz_test_pullbacks_match(
                            rng, project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ
                        )
                    end
                end
            end
            @testset "project_antihermitian" begin
                @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
                    Aₐ, ΔAₐ = ad_project_antihermitian_setup(A)
                    if eltype(T) <: BlasFloat
                        test_reverse(
                            project_antihermitian, RT, (A, TA);
                            atol, rtol, output_tangent = ΔAₐ, fdm
                        )
                    end
                    if is_cpu(A)
                        enz_test_pullbacks_match(
                            rng, project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ
                        )
                    end
                end
            end
        end
        if m > n
            @testset "project_isometric" begin
                @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
                    W, ΔW = ad_project_isometric_setup(A)
                    if eltype(T) <: BlasFloat
                        test_reverse(
                            project_isometric, RT, (A, TA);
                            atol, rtol, output_tangent = ΔW, fdm
                        )
                    end
                    if is_cpu(A)
                        enz_test_pullbacks_match(
                            rng, project_isometric!, project_isometric, A, W, ΔW
                        )
                    end
                end
            end
        end
    end
end

0 commit comments

Comments
 (0)