Skip to content

Commit 34643f8

Browse files
committed
remove letblocks
1 parent 6b352e9 commit 34643f8

1 file changed

Lines changed: 66 additions & 83 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 66 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
1313
using MatrixAlgebraKit: TruncatedAlgorithm
1414
using LinearAlgebra
1515

16-
1716
Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent
1817

1918
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
@@ -235,21 +234,19 @@ for f in (:eig, :eigh)
235234
DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ))
236235

237236
# define pullback
238-
local $f_adjoint!
239-
let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc))))
240-
function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real})
241-
_warn_pullback_truncerror(dϵ)
237+
dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc))))
238+
function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real})
239+
_warn_pullback_truncerror(dϵ)
242240

243-
# compute pullbacks
244-
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
245-
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
241+
# compute pullbacks
242+
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
243+
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
246244

247-
# restore state
248-
copy!(A, Ac)
249-
copy!.(DV, DVc)
245+
# restore state
246+
copy!(A, Ac)
247+
copy!.(DV, DVc)
250248

251-
return ntuple(Returns(NoRData()), 4)
252-
end
249+
return ntuple(Returns(NoRData()), 4)
253250
end
254251

255252
return DVtrunc_dDVtrunc, $f_adjoint!
@@ -291,14 +288,12 @@ for f in (:eig, :eigh)
291288
DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ))
292289

293290
# define pullback
294-
local $f_adjoint!
295-
let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc))))
296-
function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real})
297-
_warn_pullback_truncerror(dϵ)
298-
$f_pullback!(dA, A, DV, dDVtrunc, ind)
299-
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
300-
return ntuple(Returns(NoRData()), 3)
301-
end
291+
dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc))))
292+
function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real})
293+
_warn_pullback_truncerror(dϵ)
294+
$f_pullback!(dA, A, DV, dDVtrunc, ind)
295+
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
296+
return ntuple(Returns(NoRData()), 3)
302297
end
303298

304299
return DVtrunc_dDVtrunc, $f_adjoint!
@@ -353,19 +348,17 @@ for f in (:eig, :eigh)
353348
DVtrunc_dDVtrunc = Mooncake.zero_fcodual(DVtrunc)
354349

355350
# define pullback
356-
local $f_adjoint!
357-
let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
358-
function $f_adjoint!(::NoRData)
359-
# compute pullbacks
360-
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
361-
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
362-
363-
# restore state
364-
copy!(A, Ac)
365-
copy!.(DV, DVc)
366-
367-
return ntuple(Returns(NoRData()), 4)
368-
end
351+
dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
352+
function $f_adjoint!(::NoRData)
353+
# compute pullbacks
354+
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
355+
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
356+
357+
# restore state
358+
copy!(A, Ac)
359+
copy!.(DV, DVc)
360+
361+
return ntuple(Returns(NoRData()), 4)
369362
end
370363

371364
return DVtrunc_dDVtrunc, $f_adjoint!
@@ -405,13 +398,11 @@ for f in (:eig, :eigh)
405398
DVtrunc_dDVtrunc = Mooncake.zero_fcodual(DVtrunc)
406399

407400
# define pullback
408-
local $f_adjoint!
409-
let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
410-
function $f_adjoint!(::NoRData)
411-
$f_pullback!(dA, A, DV, dDVtrunc, ind)
412-
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
413-
return ntuple(Returns(NoRData()), 3)
414-
end
401+
dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
402+
function $f_adjoint!(::NoRData)
403+
$f_pullback!(dA, A, DV, dDVtrunc, ind)
404+
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
405+
return ntuple(Returns(NoRData()), 3)
415406
end
416407

417408
return DVtrunc_dDVtrunc, $f_adjoint!
@@ -594,22 +585,20 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
594585
USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ))
595586

596587
# define pullback
597-
local svd_trunc_adjoint
598-
let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))))
599-
function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
600-
_warn_pullback_truncerror(dϵ)
588+
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))))
589+
function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
590+
_warn_pullback_truncerror(dϵ)
601591

602-
# compute pullbacks
603-
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
604-
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
605-
zero!.(dUSVᴴ)
592+
# compute pullbacks
593+
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
594+
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
595+
zero!.(dUSVᴴ)
606596

607-
# restore state
608-
copy!(A, Ac)
609-
copy!.(USVᴴ, USVᴴc)
597+
# restore state
598+
copy!(A, Ac)
599+
copy!.(USVᴴ, USVᴴc)
610600

611-
return ntuple(Returns(NoRData()), 4)
612-
end
601+
return ntuple(Returns(NoRData()), 4)
613602
end
614603

615604
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
@@ -655,14 +644,12 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
655644
USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ))
656645

657646
# define pullback
658-
local svd_trunc_adjoint
659-
let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))))
660-
function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
661-
_warn_pullback_truncerror(dϵ)
662-
svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind)
663-
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
664-
return ntuple(Returns(NoRData()), 3)
665-
end
647+
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))))
648+
function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
649+
_warn_pullback_truncerror(dϵ)
650+
svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind)
651+
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
652+
return ntuple(Returns(NoRData()), 3)
666653
end
667654

668655
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
@@ -727,20 +714,18 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
727714
USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual(USVᴴtrunc)
728715

729716
# define pullback
730-
local svd_trunc_adjoint
731-
let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))
732-
function svd_trunc_adjoint(::NoRData)
733-
# compute pullbacks
734-
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
735-
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
736-
zero!.(dUSVᴴ)
737-
738-
# restore state
739-
copy!(A, Ac)
740-
copy!.(USVᴴ, USVᴴc)
741-
742-
return ntuple(Returns(NoRData()), 4)
743-
end
717+
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))
718+
function svd_trunc_adjoint(::NoRData)
719+
# compute pullbacks
720+
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
721+
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
722+
zero!.(dUSVᴴ)
723+
724+
# restore state
725+
copy!(A, Ac)
726+
copy!.(USVᴴ, USVᴴc)
727+
728+
return ntuple(Returns(NoRData()), 4)
744729
end
745730

746731
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
@@ -784,13 +769,11 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
784769
USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual(USVᴴtrunc)
785770

786771
# define pullback
787-
local svd_trunc_adjoint
788-
let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))
789-
function svd_trunc_adjoint(::NoRData)
790-
svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind)
791-
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
792-
return ntuple(Returns(NoRData()), 3)
793-
end
772+
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))
773+
function svd_trunc_adjoint(::NoRData)
774+
svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind)
775+
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
776+
return ntuple(Returns(NoRData()), 3)
794777
end
795778

796779
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint

0 commit comments

Comments
 (0)