Skip to content

Commit 997987d

Browse files
committed
add specializations svd_trunc(!) for TruncatedAlgorithm
1 parent 8de569f commit 997987d

1 file changed

Lines changed: 70 additions & 0 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
1010
using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
1111
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
1212
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
13+
using MatrixAlgebraKit: TruncatedAlgorithm
1314
using LinearAlgebra
1415

1516

@@ -437,6 +438,48 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
437438
end
438439
return output_codual, svd_trunc_adjoint
439440
end
441+
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
442+
# unpack variables
443+
A, dA = arrayify(A_dA)
444+
USVᴴ_dUSVᴴ_arr = arrayify.(Mooncake.primal(USVᴴ_dUSVᴴ), Mooncake.tangent(USVᴴ_dUSVᴴ))
445+
USVᴴ, dUSVᴴ = first.(USVᴴ_dUSVᴴ_arr), last.(USVᴴ_dUSVᴴ_arr)
446+
alg = Mooncake.primal(alg_dalg)
447+
448+
# store state prior to primal call
449+
Ac = copy(A)
450+
USVᴴc = copy.(USVᴴ)
451+
452+
# compute primal - capture full USVᴴ and ind
453+
USVᴴ = svd_compact!(A, USVᴴ, alg.alg)
454+
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
455+
ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind)
456+
457+
# pack output - note that we allocate new dUSVᴴtrunc because these aren't actually
458+
# overwritten in the input!
459+
USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ))
460+
461+
# define pullback
462+
local svd_trunc_adjoint
463+
let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))))
464+
function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
465+
abs(dϵ) MatrixAlgebraKit.defaulttol(dϵ) ||
466+
@warn "Pullback for `svd_trunc` ignores non-zero tangents for truncation error"
467+
468+
# compute pullbacks
469+
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
470+
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
471+
zero!.(dUSVᴴ)
472+
473+
# restore state
474+
copy!(A, Ac)
475+
copy!.(USVᴴ, USVᴴc)
476+
477+
return ntuple(Returns(NoRData()), 4)
478+
end
479+
end
480+
481+
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
482+
end
440483

441484
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
442485
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
@@ -464,6 +507,33 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
464507
end
465508
return output_codual, svd_trunc_adjoint
466509
end
510+
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
511+
# unpack variables
512+
A, dA = arrayify(A_dA)
513+
alg = Mooncake.primal(alg_dalg)
514+
515+
# compute primal - capture full USVᴴ and ind
516+
USVᴴ = svd_compact(A, alg.alg)
517+
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
518+
ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind)
519+
520+
# pack output
521+
USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ))
522+
523+
# define pullback
524+
local svd_trunc_adjoint
525+
let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))))
526+
function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
527+
abs(dϵ) MatrixAlgebraKit.defaulttol(dϵ) ||
528+
@warn "Pullback for `svd_trunc` ignores non-zero tangents for truncation error"
529+
svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind)
530+
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
531+
return ntuple(Returns(NoRData()), 3)
532+
end
533+
end
534+
535+
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
536+
end
467537

468538
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
469539
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)

0 commit comments

Comments
 (0)