Skip to content

Commit 287c5fe

Browse files
committed
add specializations svd_trunc(!) for TruncatedAlgorithm
1 parent 6fce366 commit 287c5fe

1 file changed

Lines changed: 132 additions & 0 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
1010
using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
1111
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
1212
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
13+
using MatrixAlgebraKit: TruncatedAlgorithm
1314
using LinearAlgebra
1415

1516

@@ -437,6 +438,48 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
437438
end
438439
return output_codual, svd_trunc_adjoint
439440
end
441+
# Reverse-mode rule for `svd_trunc!` specialized to `TruncatedAlgorithm`.
#
# The primal is evaluated as `svd_compact!` with `alg.alg` followed by
# `MatrixAlgebraKit.truncate` with `alg.trunc`, so that both the full factorization and
# the kept indices `ind` are available and can be handed to `svd_pullback!`.
#
# Returns the forward codual of `(USVᴴtrunc..., ϵ)` (with newly created zero tangents) and
# the pullback `svd_trunc_adjoint`. The pullback:
#   * warns when the cotangent of the truncation error `ϵ` exceeds the default tolerance,
#     since that cotangent is not propagated;
#   * calls `svd_pullback!` with `dA`, the saved copy of `A`, the factorization computed
#     in the forward pass, the tangents of the truncated factors, and `ind`;
#   * zeroes the tangents of the truncated factors and of the input buffers;
#   * restores `A` and the input buffers to their pre-call contents.
function Mooncake.rrule!!(
        ::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual,
        alg_dalg::CoDual{<:TruncatedAlgorithm}
    )
    # unpack variables
    A, dA = arrayify(A_dA)
    USVᴴ_dUSVᴴ_arr = arrayify.(Mooncake.primal(USVᴴ_dUSVᴴ), Mooncake.tangent(USVᴴ_dUSVᴴ))
    USVᴴ, dUSVᴴ = first.(USVᴴ_dUSVᴴ_arr), last.(USVᴴ_dUSVᴴ_arr)
    alg = Mooncake.primal(alg_dalg)

    # store state prior to primal call (only used to restore the inputs afterwards)
    Ac = copy(A)
    USVᴴc = copy.(USVᴴ)

    # compute primal - capture the full factorization and ind.
    # The result gets its own name: the pullback needs the computed factors, not the
    # pre-call buffer contents saved in `USVᴴc`, and not rebinding `USVᴴ` keeps the
    # closure below from capturing a reassigned variable.
    USVᴴfull = svd_compact!(A, USVᴴ, alg.alg)
    USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴfull, alg.trunc)
    ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴfull[2]), ind)

    # pack output - note that we allocate new dUSVᴴtrunc because these aren't actually
    # overwritten in the input!
    USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ))

    # define pullback
    local svd_trunc_adjoint
    let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))))
        function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
            abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) ||
                @warn "Pullback for `svd_trunc` ignores non-zero tangents for truncation error"

            # compute pullbacks using the factorization from the forward pass
            svd_pullback!(dA, Ac, USVᴴfull, dUSVᴴtrunc, ind)
            zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
            zero!.(dUSVᴴ)

            # restore state of the inputs (after the pullback, which may read buffers
            # that alias `USVᴴ`)
            copy!(A, Ac)
            copy!.(USVᴴ, USVᴴc)

            # one NoRData per argument: function, A, USVᴴ, alg
            return ntuple(Returns(NoRData()), 4)
        end
    end

    return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
end
440483

441484
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
442485
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
@@ -464,6 +507,33 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
464507
end
465508
return output_codual, svd_trunc_adjoint
466509
end
510+
# Reverse-mode rule for the non-mutating `svd_trunc` specialized to `TruncatedAlgorithm`.
#
# The primal is evaluated as `svd_compact` with `alg.alg` followed by
# `MatrixAlgebraKit.truncate` with `alg.trunc`, so that the full factorization `USVᴴ` and
# the kept indices `ind` can be handed to `svd_pullback!`.
#
# Returns the forward codual of `(USVᴴtrunc..., ϵ)` (with newly created zero tangents) and
# the pullback `svd_trunc_adjoint`. The pullback warns when the cotangent of the
# truncation error `ϵ` exceeds the default tolerance (that cotangent is not propagated),
# calls `svd_pullback!` with `dA`, and zeroes the tangents of the truncated factors.
function Mooncake.rrule!!(
        ::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}
    )
    # unpack variables
    A, dA = arrayify(A_dA)
    alg = Mooncake.primal(alg_dalg)

    # compute primal - capture full USVᴴ and ind
    USVᴴ = svd_compact(A, alg.alg)
    USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
    ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind)

    # pack output
    USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ))

    # define pullback
    local svd_trunc_adjoint
    let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))))
        function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
            abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) ||
                @warn "Pullback for `svd_trunc` ignores non-zero tangents for truncation error"
            svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind)
            zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
            # one NoRData per argument: function, A, alg
            return ntuple(Returns(NoRData()), 3)
        end
    end

    return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
end
467537

468538
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
469539
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
@@ -504,6 +574,44 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
504574
end
505575
return output_codual, svd_trunc_adjoint
506576
end
577+
# Reverse-mode rule for `svd_trunc_no_error!` specialized to `TruncatedAlgorithm`.
#
# The primal is evaluated as `svd_compact!` with `alg.alg` followed by
# `MatrixAlgebraKit.truncate` with `alg.trunc`, so that both the full factorization and
# the kept indices `ind` are available and can be handed to `svd_pullback!`.
#
# Returns the forward codual of the truncated factors (with newly created zero tangents)
# and the pullback `svd_trunc_adjoint`. The pullback:
#   * calls `svd_pullback!` with `dA`, the saved copy of `A`, the factorization computed
#     in the forward pass, the tangents of the truncated factors, and `ind`;
#   * zeroes the tangents of the truncated factors and of the input buffers;
#   * restores `A` and the input buffers to their pre-call contents.
function Mooncake.rrule!!(
        ::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual,
        alg_dalg::CoDual{<:TruncatedAlgorithm}
    )
    # unpack variables
    A, dA = arrayify(A_dA)
    USVᴴ_dUSVᴴ_arr = arrayify.(Mooncake.primal(USVᴴ_dUSVᴴ), Mooncake.tangent(USVᴴ_dUSVᴴ))
    USVᴴ, dUSVᴴ = first.(USVᴴ_dUSVᴴ_arr), last.(USVᴴ_dUSVᴴ_arr)
    alg = Mooncake.primal(alg_dalg)

    # store state prior to primal call (only used to restore the inputs afterwards)
    Ac = copy(A)
    USVᴴc = copy.(USVᴴ)

    # compute primal - capture the full factorization and ind.
    # The result gets its own name: the pullback needs the computed factors, not the
    # pre-call buffer contents saved in `USVᴴc`, and not rebinding `USVᴴ` keeps the
    # closure below from capturing a reassigned variable.
    USVᴴfull = svd_compact!(A, USVᴴ, alg.alg)
    USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴfull, alg.trunc)

    # pack output - note that we allocate new dUSVᴴtrunc because these aren't actually
    # overwritten in the input!
    USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual(USVᴴtrunc)

    # define pullback
    local svd_trunc_adjoint
    let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))
        function svd_trunc_adjoint(::NoRData)
            # compute pullbacks using the factorization from the forward pass
            svd_pullback!(dA, Ac, USVᴴfull, dUSVᴴtrunc, ind)
            zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
            zero!.(dUSVᴴ)

            # restore state of the inputs (after the pullback, which may read buffers
            # that alias `USVᴴ`)
            copy!(A, Ac)
            copy!.(USVᴴ, USVᴴc)

            # one NoRData per argument: function, A, USVᴴ, alg
            return ntuple(Returns(NoRData()), 4)
        end
    end

    return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
end
507615

508616
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm}
509617
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual)
@@ -530,5 +638,29 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
530638
end
531639
return output_codual, svd_trunc_adjoint
532640
end
641+
# Reverse-mode rule for the non-mutating `svd_trunc_no_error` specialized to
# `TruncatedAlgorithm`.
#
# The forward pass runs `svd_compact` with the wrapped algorithm and then truncates with
# the wrapped truncation strategy, keeping the full factorization and the kept indices
# around for `svd_pullback!`. The returned pullback calls `svd_pullback!` with `dA` and
# then zeroes the tangents of the truncated factors.
function Mooncake.rrule!!(
        ::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual,
        alg_dalg::CoDual{<:TruncatedAlgorithm}
    )
    # primal matrix, its tangent, and the truncated algorithm
    A, dA = arrayify(A_dA)
    truncated_alg = Mooncake.primal(alg_dalg)

    # forward pass: full compact SVD first, truncation second
    full_factors = svd_compact(A, truncated_alg.alg)
    kept_factors, kept_ind = MatrixAlgebraKit.truncate(
        svd_trunc!, full_factors, truncated_alg.trunc
    )

    # output codual with zero tangents for the truncated factors
    output_codual = Mooncake.zero_fcodual(kept_factors)
    factor_tangent_pairs = arrayify.(kept_factors, Mooncake.tangent(output_codual))
    kept_tangents = last.(factor_tangent_pairs)

    # pullback: the only incoming rdata is NoRData, the work is done on the tangents
    svd_trunc_adjoint = let ind = kept_ind, dfactors = kept_tangents
        function (::NoRData)
            svd_pullback!(dA, A, full_factors, dfactors, ind)
            # these tangents are created in this rule, so resetting them is likely optional
            zero!.(dfactors)
            # one NoRData per argument: function, A, alg
            return (NoRData(), NoRData(), NoRData())
        end
    end

    return output_codual, svd_trunc_adjoint
end
533665

534666
end

0 commit comments

Comments
 (0)