Skip to content

Commit fd84210

Browse files
committed
small reorganization
1 parent 7f6f16d commit fd84210

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,23 @@ using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
1313
using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm
1414
using LinearAlgebra
1515

16-
Mooncake.tangent_type(::Type{<:AbstractAlgorithm}) = Mooncake.NoTangent
1716

17+
# Utility
18+
# -------
19+
# convenience helper for marking DefaultCtx ReverseMode signature as primitive
1820
macro is_rev_primitive(sig)
1921
return esc(:(Mooncake.@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode $sig))
2022
end
2123

2224
# return n copies of NoRData()
2325
@inline n_NoRData(n) = ntuple(Returns(NoRData()), n)
2426

27+
# No derivatives
28+
# --------------
29+
Mooncake.tangent_type(::Type{<:AbstractAlgorithm}) = Mooncake.NoTangent
30+
Mooncake.@zero_derivative DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any}
31+
32+
2533
@is_rev_primitive Tuple{typeof(copy_input), Any, Any}
2634
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
2735
Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
@@ -34,7 +42,6 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu
3442
return Ac_dAc, copy_input_pb
3543
end
3644

37-
Mooncake.@zero_derivative DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any}
3845
# two-argument in-place factorizations like LQ, QR, EIG
3946
for (f!, f, pb, adj) in (
4047
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),

0 commit comments

Comments
 (0)