@@ -779,82 +779,44 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
779779 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
780780end
781781
782- # single-output projections: project_hermitian!, project_antihermitian!
783782# single-output projections: project_hermitian!, project_antihermitian!
784783for (f!, f, adj) in (
785784 (:project_hermitian! , :project_hermitian , :project_hermitian_adjoint ),
786785 (:project_antihermitian! , :project_antihermitian , :project_antihermitian_adjoint ),
787786 )
788787 @eval begin
789- @is_primitive Mooncake . DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
788+ @is_primitive DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
790789 function Mooncake. rrule!! (f_df:: CoDual{typeof($f!)} , A_dA:: CoDual , arg_darg:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit.AbstractAlgorithm} )
791790 A, dA = arrayify (A_dA)
792- Ac = copy (A)
793- arg, darg = arrayify (arg_darg)
791+ arg, darg = A_dA === arg_darg ? (A, dA) : arrayify (arg_darg)
794792 argc = copy (arg)
795- $ f! (A, arg, Mooncake. primal (alg_dalg))
793+ arg = $ f! (A, arg, Mooncake. primal (alg_dalg))
794+
796795 function $adj (:: NoRData )
797- copy! (A, Ac)
798796 dA .+ = $ f (darg)
797+ dA === darg || zero! (darg)
799798 copy! (arg, argc)
800- zero! (darg)
801799 return NoRData (), NoRData (), NoRData (), NoRData ()
802800 end
803801 return arg_darg, $ adj
804802 end
805- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit. AbstractAlgorithm}
803+
804+ @is_primitive DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit. AbstractAlgorithm}
806805 function Mooncake. rrule!! (f_df:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit.AbstractAlgorithm} )
807806 A, dA = arrayify (A_dA)
808807 output = $ f (A, Mooncake. primal (alg_dalg))
809- output_codual = CoDual (output, Mooncake. zero_tangent (output))
808+ output_doutput = Mooncake. zero_fcodual (output)
809+
810+ doutput = last (arrayify (output_doutput))
810811 function $adj (:: NoRData )
811- arg, darg = arrayify (output_codual)
812- dA .+ = $ f (darg)
813- zero! (darg)
814- return NoRData (), NoRData (), NoRData ()
812+ # TODO : need accumulating projection to avoid intermediate here
813+ dA .+ = $ f (doutput)
814+ return ntuple (Returns (NoRData (), 3 ))
815815 end
816+
816817 return output_codual, $ adj
817818 end
818819 end
819820end
820821
821- # project_isometric! needs special handling: compute full polar decomposition
822- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (project_isometric!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
823- function Mooncake. rrule!! (f_df:: CoDual{typeof(project_isometric!)} , A_dA:: CoDual , W_dW:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit.AbstractAlgorithm} )
824- A, dA = arrayify (A_dA)
825- W, dW = arrayify (W_dW)
826- Ac = copy (A)
827- Wc = copy (W)
828- # Compute the full polar decomposition for the pullback
829- m, n = size (A)
830- P = similar (A, n, n)
831- WP = left_polar! (copy (A), (copy (W), P), Mooncake. primal (alg_dalg))
832- copy! (W, WP[1 ])
833- function project_isometric_adjoint (:: NoRData )
834- copy! (A, Ac)
835- left_polar_pullback! (dA, A, WP, (dW, nothing ))
836- copy! (W, Wc)
837- zero! (dW)
838- return NoRData (), NoRData (), NoRData (), NoRData ()
839- end
840- return W_dW, project_isometric_adjoint
841- end
842-
843- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (project_isometric), Any, MatrixAlgebraKit. AbstractAlgorithm}
844- function Mooncake. rrule!! (f_df:: CoDual{typeof(project_isometric)} , A_dA:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit.AbstractAlgorithm} )
845- A, dA = arrayify (A_dA)
846- alg = Mooncake. primal (alg_dalg)
847- # Compute the full polar decomposition for the pullback
848- WP = left_polar (A, alg)
849- W_out = WP[1 ]
850- output_codual = CoDual (W_out, Mooncake. zero_tangent (W_out))
851- function project_isometric_adjoint (:: NoRData )
852- W, dW = arrayify (output_codual)
853- left_polar_pullback! (dA, A, WP, (dW, nothing ))
854- zero! (dW)
855- return NoRData (), NoRData (), NoRData ()
856- end
857- return output_codual, project_isometric_adjoint
858- end
859-
860822end
0 commit comments