@@ -10,6 +10,7 @@ using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
1010using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
1111using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
1212using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
13+ using MatrixAlgebraKit: TruncatedAlgorithm
1314using LinearAlgebra
1415
1516
@@ -437,6 +438,48 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
437438 end
438439 return output_codual, svd_trunc_adjoint
439440end
441+ function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc!)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual{<:TruncatedAlgorithm} )
442+ # unpack variables
443+ A, dA = arrayify (A_dA)
444+ USVᴴ_dUSVᴴ_arr = arrayify .(Mooncake. primal (USVᴴ_dUSVᴴ), Mooncake. tangent (USVᴴ_dUSVᴴ))
445+ USVᴴ, dUSVᴴ = first .(USVᴴ_dUSVᴴ_arr), last .(USVᴴ_dUSVᴴ_arr)
446+ alg = Mooncake. primal (alg_dalg)
447+
448+ # store state prior to primal call
449+ Ac = copy (A)
450+ USVᴴc = copy .(USVᴴ)
451+
452+ # compute primal - capture full USVᴴ and ind
453+ USVᴴ = svd_compact! (A, USVᴴ, alg. alg)
454+ USVᴴtrunc, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
455+ ϵ = MatrixAlgebraKit. truncation_error (diagview (USVᴴ[2 ]), ind)
456+
457+ # pack output - note that we allocate new dUSVᴴtrunc because these aren't actually
458+ # overwritten in the input!
459+ USVᴴtrunc_dUSVᴴtrunc = Mooncake. zero_fcodual ((USVᴴtrunc... , ϵ))
460+
461+ # define pullback
462+ local svd_trunc_adjoint
463+ let ind = ind, dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Base. front (Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc))))
464+ function svd_trunc_adjoint ((_, _, _, dϵ):: Tuple{NoRData, NoRData, NoRData, Real} )
465+ abs (dϵ) ≤ MatrixAlgebraKit. defaulttol (dϵ) ||
466+ @warn " Pullback for `svd_trunc` ignores non-zero tangents for truncation error"
467+
468+ # compute pullbacks
469+ svd_pullback! (dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
470+ zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
471+ zero! .(dUSVᴴ)
472+
473+ # restore state
474+ copy! (A, Ac)
475+ copy! .(USVᴴ, USVᴴc)
476+
477+ return ntuple (Returns (NoRData ()), 4 )
478+ end
479+ end
480+
481+ return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
482+ end
440483
441484@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc), Any, MatrixAlgebraKit. AbstractAlgorithm}
442485function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual )
@@ -464,6 +507,33 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
464507 end
465508 return output_codual, svd_trunc_adjoint
466509end
510+ function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual{<:TruncatedAlgorithm} )
511+ # unpack variables
512+ A, dA = arrayify (A_dA)
513+ alg = Mooncake. primal (alg_dalg)
514+
515+ # compute primal - capture full USVᴴ and ind
516+ USVᴴ = svd_compact (A, alg. alg)
517+ USVᴴtrunc, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
518+ ϵ = MatrixAlgebraKit. truncation_error (diagview (USVᴴ[2 ]), ind)
519+
520+ # pack output
521+ USVᴴtrunc_dUSVᴴtrunc = Mooncake. zero_fcodual ((USVᴴtrunc... , ϵ))
522+
523+ # define pullback
524+ local svd_trunc_adjoint
525+ let ind = ind, dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Base. front (Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc))))
526+ function svd_trunc_adjoint ((_, _, _, dϵ):: Tuple{NoRData, NoRData, NoRData, Real} )
527+ abs (dϵ) ≤ MatrixAlgebraKit. defaulttol (dϵ) ||
528+ @warn " Pullback for `svd_trunc` ignores non-zero tangents for truncation error"
529+ svd_pullback! (dA, A, USVᴴ, dUSVᴴtrunc, ind)
530+ zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
531+ return ntuple (Returns (NoRData ()), 3 )
532+ end
533+ end
534+
535+ return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
536+ end
467537
468538@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc_no_error!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
469539function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc_no_error!)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual )
0 commit comments