Skip to content

Commit 62dd5ca

Browse files
committed
simplify implementations
1 parent 5474c5e commit 62dd5ca

2 files changed

Lines changed: 20 additions & 82 deletions

File tree

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -274,46 +274,22 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A, PWᴴ, alg)
274274
return PWᴴ, right_polar_pullback
275275
end
276276

277-
function ChainRulesCore.rrule(::typeof(project_hermitian!), A, Aₕ, alg)
278-
Ac = copy_input(project_hermitian, A)
279-
Aₕ = project_hermitian!(Ac, Aₕ, alg)
277+
function ChainRulesCore.rrule(::typeof(project_hermitian), A, alg)
278+
Aₕ = project_hermitian(A, alg)
280279
function project_hermitian_pullback(ΔAₕ)
281280
ΔA = project_hermitian(unthunk(ΔAₕ))
282-
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
283-
end
284-
function project_hermitian_pullback(::ZeroTangent)
285-
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
281+
return NoTangent(), ΔA, NoTangent()
286282
end
287283
return Aₕ, project_hermitian_pullback
288284
end
289285

290-
function ChainRulesCore.rrule(::typeof(project_antihermitian!), A, Aₐ, alg)
291-
Ac = copy_input(project_antihermitian, A)
292-
Aₐ = project_antihermitian!(Ac, Aₐ, alg)
286+
function ChainRulesCore.rrule(::typeof(project_antihermitian), A, alg)
287+
Aₐ = project_antihermitian(A, alg)
293288
function project_antihermitian_pullback(ΔAₐ)
294289
ΔA = project_antihermitian(unthunk(ΔAₐ))
295-
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
296-
end
297-
function project_antihermitian_pullback(::ZeroTangent)
298-
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
290+
return NoTangent(), ΔA, NoTangent()
299291
end
300292
return Aₐ, project_antihermitian_pullback
301293
end
302294

303-
function ChainRulesCore.rrule(::typeof(project_isometric!), A, W, alg)
304-
Ac = copy_input(project_isometric, A)
305-
# Compute the full polar decomposition to cache P for the pullback
306-
WP = left_polar!(Ac, (similar(W), similar(W, size(W, 2), size(W, 2))), alg)
307-
W_out = copy!(W, WP[1])
308-
function project_isometric_pullback(ΔW)
309-
ΔA = zero(A)
310-
MatrixAlgebraKit.left_polar_pullback!(ΔA, A, WP, (unthunk(ΔW), nothing))
311-
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
312-
end
313-
function project_isometric_pullback(::ZeroTangent)
314-
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
315-
end
316-
return W_out, project_isometric_pullback
317-
end
318-
319295
end

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 14 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -779,82 +779,44 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
779779
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
780780
end
781781

782-
# single-output projections: project_hermitian!, project_antihermitian!
783782
# single-output projections: project_hermitian(!), project_antihermitian(!)
784783
for (f!, f, adj) in (
785784
(:project_hermitian!, :project_hermitian, :project_hermitian_adjoint),
786785
(:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint),
787786
)
788787
@eval begin
789-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
788+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
790789
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
791790
A, dA = arrayify(A_dA)
792-
Ac = copy(A)
793-
arg, darg = arrayify(arg_darg)
791+
arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg)
794792
argc = copy(arg)
795-
$f!(A, arg, Mooncake.primal(alg_dalg))
793+
arg = $f!(A, arg, Mooncake.primal(alg_dalg))
794+
796795
function $adj(::NoRData)
797-
copy!(A, Ac)
798796
dA .+= $f(darg)
797+
dA === darg || zero!(darg)
799798
copy!(arg, argc)
800-
zero!(darg)
801799
return NoRData(), NoRData(), NoRData(), NoRData()
802800
end
803801
return arg_darg, $adj
804802
end
805-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
803+
804+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
806805
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
807806
A, dA = arrayify(A_dA)
808807
output = $f(A, Mooncake.primal(alg_dalg))
809-
output_codual = CoDual(output, Mooncake.zero_tangent(output))
808+
output_codual = Mooncake.zero_fcodual(output)  # keep this name: the final `return output_codual, $adj` relies on it
809+
810+
doutput = last(arrayify(output_codual))  # tangent array of the output, read by the adjoint below
810811
function $adj(::NoRData)
811-
arg, darg = arrayify(output_codual)
812-
dA .+= $f(darg)
813-
zero!(darg)
814-
return NoRData(), NoRData(), NoRData()
812+
# TODO: need accumulating projection to avoid intermediate here
813+
dA .+= $f(doutput)
814+
return ntuple(Returns(NoRData()), 3)  # one NoRData each for f, A and alg
815815
end
816+
816817
return output_codual, $adj
817818
end
818819
end
819820
end
820821

821-
# project_isometric! needs special handling: compute full polar decomposition
822-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(project_isometric!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
823-
function Mooncake.rrule!!(f_df::CoDual{typeof(project_isometric!)}, A_dA::CoDual, W_dW::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
824-
A, dA = arrayify(A_dA)
825-
W, dW = arrayify(W_dW)
826-
Ac = copy(A)
827-
Wc = copy(W)
828-
# Compute the full polar decomposition for the pullback
829-
m, n = size(A)
830-
P = similar(A, n, n)
831-
WP = left_polar!(copy(A), (copy(W), P), Mooncake.primal(alg_dalg))
832-
copy!(W, WP[1])
833-
function project_isometric_adjoint(::NoRData)
834-
copy!(A, Ac)
835-
left_polar_pullback!(dA, A, WP, (dW, nothing))
836-
copy!(W, Wc)
837-
zero!(dW)
838-
return NoRData(), NoRData(), NoRData(), NoRData()
839-
end
840-
return W_dW, project_isometric_adjoint
841-
end
842-
843-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(project_isometric), Any, MatrixAlgebraKit.AbstractAlgorithm}
844-
function Mooncake.rrule!!(f_df::CoDual{typeof(project_isometric)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
845-
A, dA = arrayify(A_dA)
846-
alg = Mooncake.primal(alg_dalg)
847-
# Compute the full polar decomposition for the pullback
848-
WP = left_polar(A, alg)
849-
W_out = WP[1]
850-
output_codual = CoDual(W_out, Mooncake.zero_tangent(W_out))
851-
function project_isometric_adjoint(::NoRData)
852-
W, dW = arrayify(output_codual)
853-
left_polar_pullback!(dA, A, WP, (dW, nothing))
854-
zero!(dW)
855-
return NoRData(), NoRData(), NoRData()
856-
end
857-
return output_codual, project_isometric_adjoint
858-
end
859-
860822
end

0 commit comments

Comments
 (0)