@@ -10,6 +10,7 @@ using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
1010using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
1111using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
1212using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
13+ using MatrixAlgebraKit: TruncatedAlgorithm
1314using LinearAlgebra
1415
1516
@@ -437,6 +438,48 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
437438 end
438439 return output_codual, svd_trunc_adjoint
439440end
441+ function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc!)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual{<:TruncatedAlgorithm} )
442+ # unpack variables
443+ A, dA = arrayify (A_dA)
444+ USVᴴ_dUSVᴴ_arr = arrayify .(Mooncake. primal (USVᴴ_dUSVᴴ), Mooncake. tangent (USVᴴ_dUSVᴴ))
445+ USVᴴ, dUSVᴴ = first .(USVᴴ_dUSVᴴ_arr), last .(USVᴴ_dUSVᴴ_arr)
446+ alg = Mooncake. primal (alg_dalg)
447+
448+ # store state prior to primal call
449+ Ac = copy (A)
450+ USVᴴc = copy .(USVᴴ)
451+
452+ # compute primal - capture full USVᴴ and ind
453+ USVᴴ = svd_compact! (A, USVᴴ, alg. alg)
454+ USVᴴtrunc, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
455+ ϵ = MatrixAlgebraKit. truncation_error (diagview (USVᴴ[2 ]), ind)
456+
457+ # pack output - note that we allocate new dUSVᴴtrunc because these aren't actually
458+ # overwritten in the input!
459+ USVᴴtrunc_dUSVᴴtrunc = Mooncake. zero_fcodual ((USVᴴtrunc... , ϵ))
460+
461+ # define pullback
462+ local svd_trunc_adjoint
463+ let ind = ind, dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Base. front (Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc))))
464+ function svd_trunc_adjoint ((_, _, _, dϵ):: Tuple{NoRData, NoRData, NoRData, Real} )
465+ abs (dϵ) ≤ MatrixAlgebraKit. defaulttol (dϵ) ||
466+ @warn " Pullback for `svd_trunc` ignores non-zero tangents for truncation error"
467+
468+ # compute pullbacks
469+ svd_pullback! (dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
470+ zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
471+ zero! .(dUSVᴴ)
472+
473+ # restore state
474+ copy! (A, Ac)
475+ copy! .(USVᴴ, USVᴴc)
476+
477+ return ntuple (Returns (NoRData ()), 4 )
478+ end
479+ end
480+
481+ return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
482+ end
440483
441484@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc), Any, MatrixAlgebraKit. AbstractAlgorithm}
442485function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual )
@@ -464,6 +507,33 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
464507 end
465508 return output_codual, svd_trunc_adjoint
466509end
510+ function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual{<:TruncatedAlgorithm} )
511+ # unpack variables
512+ A, dA = arrayify (A_dA)
513+ alg = Mooncake. primal (alg_dalg)
514+
515+ # compute primal - capture full USVᴴ and ind
516+ USVᴴ = svd_compact (A, alg. alg)
517+ USVᴴtrunc, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
518+ ϵ = MatrixAlgebraKit. truncation_error (diagview (USVᴴ[2 ]), ind)
519+
520+ # pack output
521+ USVᴴtrunc_dUSVᴴtrunc = Mooncake. zero_fcodual ((USVᴴtrunc... , ϵ))
522+
523+ # define pullback
524+ local svd_trunc_adjoint
525+ let ind = ind, dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Base. front (Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc))))
526+ function svd_trunc_adjoint ((_, _, _, dϵ):: Tuple{NoRData, NoRData, NoRData, Real} )
527+ abs (dϵ) ≤ MatrixAlgebraKit. defaulttol (dϵ) ||
528+ @warn " Pullback for `svd_trunc` ignores non-zero tangents for truncation error"
529+ svd_pullback! (dA, A, USVᴴ, dUSVᴴtrunc, ind)
530+ zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
531+ return ntuple (Returns (NoRData ()), 3 )
532+ end
533+ end
534+
535+ return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
536+ end
467537
468538@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc_no_error!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
469539function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc_no_error!)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual )
@@ -504,6 +574,44 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
504574 end
505575 return output_codual, svd_trunc_adjoint
506576end
577+ function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc_no_error!)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual{<:TruncatedAlgorithm} )
578+ # unpack variables
579+ A, dA = arrayify (A_dA)
580+ USVᴴ_dUSVᴴ_arr = arrayify .(Mooncake. primal (USVᴴ_dUSVᴴ), Mooncake. tangent (USVᴴ_dUSVᴴ))
581+ USVᴴ, dUSVᴴ = first .(USVᴴ_dUSVᴴ_arr), last .(USVᴴ_dUSVᴴ_arr)
582+ alg = Mooncake. primal (alg_dalg)
583+
584+ # store state prior to primal call
585+ Ac = copy (A)
586+ USVᴴc = copy .(USVᴴ)
587+
588+ # compute primal - capture full USVᴴ and ind
589+ USVᴴ = svd_compact! (A, USVᴴ, alg. alg)
590+ USVᴴtrunc, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
591+
592+ # pack output - note that we allocate new dUSVᴴtrunc because these aren't actually
593+ # overwritten in the input!
594+ USVᴴtrunc_dUSVᴴtrunc = Mooncake. zero_fcodual (USVᴴtrunc)
595+
596+ # define pullback
597+ local svd_trunc_adjoint
598+ let ind = ind, dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc)))
599+ function svd_trunc_adjoint (:: NoRData )
600+ # compute pullbacks
601+ svd_pullback! (dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
602+ zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
603+ zero! .(dUSVᴴ)
604+
605+ # restore state
606+ copy! (A, Ac)
607+ copy! .(USVᴴ, USVᴴc)
608+
609+ return ntuple (Returns (NoRData ()), 4 )
610+ end
611+ end
612+
613+ return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
614+ end
507615
508616@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc_no_error), Any, MatrixAlgebraKit. AbstractAlgorithm}
509617function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc_no_error)} , A_dA:: CoDual , alg_dalg:: CoDual )
@@ -530,5 +638,29 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
530638 end
531639 return output_codual, svd_trunc_adjoint
532640end
641+ function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc_no_error)} , A_dA:: CoDual , alg_dalg:: CoDual{<:TruncatedAlgorithm} )
642+ # unpack variables
643+ A, dA = arrayify (A_dA)
644+ alg = Mooncake. primal (alg_dalg)
645+
646+ # compute primal - capture full USVᴴ and ind
647+ USVᴴ = svd_compact (A, alg. alg)
648+ USVᴴtrunc, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
649+
650+ # pack output
651+ USVᴴtrunc_dUSVᴴtrunc = Mooncake. zero_fcodual (USVᴴtrunc)
652+
653+ # define pullback
654+ local svd_trunc_adjoint
655+ let ind = ind, dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc)))
656+ function svd_trunc_adjoint (:: NoRData )
657+ svd_pullback! (dA, A, USVᴴ, dUSVᴴtrunc, ind)
658+ zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
659+ return ntuple (Returns (NoRData ()), 3 )
660+ end
661+ end
662+
663+ return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
664+ end
533665
534666end
0 commit comments