I haven't reviewed anything yet, but I'd already like to suggest using pushforward instead of "pull forward" 😄. If you make this change, may I also request that the folder be named "pushforwards"? (I am not too enthusiastic about random abbreviations like leaving out vowels.)

Ok, I had a brief look already, only to conclude that this will be a tough PR to review 😸. If I have time this week, I will try to go through the generic pushforward definitions. For all the Moonzyme-specific stuff, I will probably need a bit of an introduction first (or read the respective manuals).
|
@Jutho, would it make more sense to remove the forward-mode code entirely from this PR and shunt it off into a new one?
|
Maybe, as you wish. But then you cannot test it in that PR, I assume. Testing is not necessary for reviewing it, though, so splitting it off could still be helpful.
|
I do think moving the fwd mode stuff out might make this substantially easier to review, actually. We can always test the forward rules once the (working) reverse ones are in place. |
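A minimal sketch of how forward rules could be tested once working reverse rules exist, assuming a dot-product (adjoint) test is acceptable: for any tangent `dA` and cotangent `Xbar`, a consistent pair must satisfy ⟨Xbar, J·dA⟩ = ⟨Jᵀ·Xbar, dA⟩. The rules below are hypothetical hand-written ones for `X = inv(A)` in the real case (`inv_pushforward` and `inv_pullback` are illustrative names, not MatrixAlgebraKit functions); the idea should carry over to other pushforward/pullback pairs, modulo gauge freedom in factorizations.

```julia
using LinearAlgebra
using Random

# Hypothetical rules for X = inv(A), real case only, used to illustrate the test.
inv_pushforward(A, X, dA) = -X * dA * X          # forward: dX = -X dA X
inv_pullback(A, X, Xbar) = -X' * Xbar * X'       # reverse: Abar = -Xᵀ Xbar Xᵀ

Random.seed!(0)
n = 4
A = randn(n, n) + n * I   # shift the diagonal to keep A well-conditioned
X = inv(A)
dA = randn(n, n)          # tangent (forward-mode seed)
Xbar = randn(n, n)        # cotangent (reverse-mode seed)

lhs = dot(Xbar, inv_pushforward(A, X, dA))   # ⟨Xbar, J dA⟩
rhs = dot(inv_pullback(A, X, Xbar), dA)      # ⟨Jᵀ Xbar, dA⟩
@assert isapprox(lhs, rhs; rtol = 1e-10)
```

For complex inputs the comparison would use the real part of the inner products.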
|
I think I agree; it might be nice to split this up into several separate parts to get things moving more easily. I would propose the following:
The rest definitely looks great. Do you think it would be reasonable to schedule a meeting to go over the general Mooncake and/or Enzyme approach as a whole? Obviously I don't want to put this on you if you aren't up for it, and I can try to read through them myself, but if you are okay with it, it might be nice to get the explanation from someone who already has some experience with it :)
|
Points 1-3 I 100% agree on. Would also be happy to chat/do a walkthrough of both packages (as much as I can...). |
ltRtmp = view(Rtmp, MatrixAlgebraKit.lowertriangularind(Rtmp))
ltRtmp .= zero(eltype(Rtmp))
dR11 .= Rtmp * R11
dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11
Since dR11 = Rtmp * R11, dR11 * invR11 is just Rtmp, so that can be simplified.
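A quick numerical sketch of the suggested simplification, using random stand-in matrices (shapes and values are assumptions, not taken from the PR): because `dR11 = Rtmp * R11`, the term `Q1 * dR11 * invR11` collapses to `Q1 * Rtmp`, and the update can be done in place with two `mul!` calls and one product fewer.

```julia
using LinearAlgebra
using Random

Random.seed!(1)
m, n = 6, 3
# Illustrative inputs: R11 is upper triangular with diagonal entries in [1, 2),
# hence invertible.
R11 = triu(rand(n, n)) + I
invR11 = inv(R11)
Rtmp = triu(randn(n, n))   # the identity holds for any Rtmp; triu only mimics the zeroing step
dA1 = randn(m, n)
Q1 = randn(m, n)

# Original expression from the snippet.
dR11 = Rtmp * R11
dQ1_orig = dA1 * invR11 - Q1 * dR11 * invR11

# Simplified, in place: dQ1 = dA1 * invR11 - Q1 * Rtmp.
dQ1 = similar(dA1)
mul!(dQ1, dA1, invR11)       # dQ1 = dA1 * invR11
mul!(dQ1, Q1, Rtmp, -1, 1)   # dQ1 = -Q1 * Rtmp + dQ1
@assert dQ1 ≈ dQ1_orig
```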
dQ, dR = dQR
dQ1 = view(dQ, 1:m, 1:m1)
dQ2 = view(dQ, 1:m, m1+1:m2+m1)
dQ3 = m1+m2+1 < size(dQ, 2) ? view(dQ, 1:m, m1+m2+1:size(dQ,2)) : similar(dQ, eltype(dQ), (0, 0))
Do you mean m1+m2 < size(dQ, 2), or equivalently m1+m2+1 <= size(dQ, 2)?
Also, wouldn't just using view(dQ, 1:m, m1+m2+1:size(dQ,2)) always work? In the case that m1+m2 == size(dQ, 2), we automatically get a view with size(dQ3, 2) == 0, no?
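A small sketch of the point about empty views, with hypothetical block sizes: in Julia an empty range such as `6:5` is always in bounds, so the plain `view` yields an `m × 0` block without any branch, whereas the `m1+m2+1 < size(dQ, 2)` guard returns a `0 × 0` array even when one trailing column remains. `trailing_block` is an illustrative helper name, not part of the PR.

```julia
# Hypothetical helper mirroring the plain `view` suggested above.
trailing_block(dQ, m1, m2) = view(dQ, 1:size(dQ, 1), (m1 + m2 + 1):size(dQ, 2))

dQ = randn(5, 5)
size(trailing_block(dQ, 2, 3))  # (5, 0): empty column range 6:5, no branch needed
size(trailing_block(dQ, 2, 2))  # (5, 1): the original guard (5 < 5 is false) would drop this column
```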
This logic is a bit out of date (I'm working on a new, cleaner pushforwards PR 😉), but yes, I think these are correct observations.
Your PR requires formatting changes to meet the project's style guidelines. The suggested changes:
diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
index a80dce6..2d795ed 100644
--- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
+++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
@@ -16,43 +16,47 @@ using LinearAlgebra
# two-argument factorizations like LQ, QR, EIG
-for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pushforward!),
- (qr_compact!, qr_pullback!, qr_pushforward!),
- (lq_full!, lq_pullback!, lq_pushforward!),
- (lq_compact!, lq_pullback!, lq_pushforward!),
- (eig_full!, eig_pullback!, eig_pushforward!),
- (left_polar!, left_polar_pullback!, left_polar_pushforward!),
- (right_polar!, right_polar_pullback!, right_polar_pushforward!),
- )
+for (f, pb, pf) in (
+ (qr_full!, qr_pullback!, qr_pushforward!),
+ (qr_compact!, qr_pullback!, qr_pushforward!),
+ (lq_full!, lq_pullback!, lq_pushforward!),
+ (lq_compact!, lq_pullback!, lq_pushforward!),
+ (eig_full!, eig_pullback!, eig_pushforward!),
+ (left_polar!, left_polar_pullback!, left_polar_pushforward!),
+ (right_polar!, right_polar_pullback!, right_polar_pushforward!),
+ )
@eval begin
- function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof($f)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+ function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof($f)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_arg = nothing
# form cache if needed
- cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(arg) <: Const)) ? copy(A.val) : nothing
+ cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(arg) <: Const)) ? copy(A.val) : nothing
func.val(A.val, arg.val, alg.val; kwargs...)
- primal = EnzymeRules.needs_primal(config) ? arg.val : nothing
+ primal = EnzymeRules.needs_primal(config) ? arg.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg))
end
- function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof($f)},
- dret::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...) where {RT}
+ function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof($f)},
+ dret::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...
+ ) where {RT}
cache_A, cache_arg = cache
argval = arg.val
- Aval = !isnothing(cache_A) ? cache_A : A.val
- ∂arg = isa(arg, Const) ? nothing : arg.dval
+ Aval = !isnothing(cache_A) ? cache_A : A.val
+ ∂arg = isa(arg, Const) ? nothing : arg.dval
if !isa(A, Const) && !isa(arg, Const)
A.dval .= zero(eltype(Aval))
$pb(A.dval, A.val, argval, ∂arg; kwargs...)
@@ -60,24 +64,25 @@ for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pushforward!),
!isa(arg, Const) && make_zero!(arg.dval)
return (nothing, nothing, nothing)
end
- function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof($f)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
- ret = func.val(A.val, arg.val, alg.val; kwargs...)
+ function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof($f)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
+ ret = func.val(A.val, arg.val, alg.val; kwargs...)
arg1, arg2 = ret
m, n = size(A.val)
if isa(arg, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const)
darg1, darg2 = arg.dval
- dA = A.dval
+ dA = A.dval
darg1, darg2 = $pf(dA, A.val, ret, arg.dval)
- dA .= zero(eltype(A.val))
- shadow = (darg1, darg2)
+ dA .= zero(eltype(A.val))
+ shadow = (darg1, darg2)
elseif isa(A, Const) && !!isa(arg, Union{Duplicated, DuplicatedNoNeed})
make_zero!(arg.dval)
shadow = arg.dval
@@ -96,54 +101,59 @@ for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pushforward!),
end
end
-for (f, pb, pf) in ((qr_null!, qr_null_pullback!, qr_null_pushforward!),
- (lq_null!, lq_null_pullback!, lq_null_pushforward!),
- )
+for (f, pb, pf) in (
+ (qr_null!, qr_null_pullback!, qr_null_pushforward!),
+ (lq_null!, lq_null_pullback!, lq_null_pushforward!),
+ )
@eval begin
- function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof($f)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- arg::Annotation{<:AbstractMatrix},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+ function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof($f)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ arg::Annotation{<:AbstractMatrix},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_arg = nothing
# form cache if needed
cache_A = nothing #copy(A.val)
func.val(copy(A.val), arg.val, alg.val; kwargs...)
- primal = EnzymeRules.needs_primal(config) ? arg.val : nothing
+ primal = EnzymeRules.needs_primal(config) ? arg.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg))
end
-
- function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof($f)},
- dret::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- arg::Annotation{<:AbstractMatrix},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(arg.val),
- rank_atol::Real=tol,
- gauge_atol::Real=tol,
- kwargs...) where {RT}
+
+ function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof($f)},
+ dret::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ arg::Annotation{<:AbstractMatrix},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ tol::Real = MatrixAlgebraKit.default_pullback_gaugetol(arg.val),
+ rank_atol::Real = tol,
+ gauge_atol::Real = tol,
+ kwargs...
+ ) where {RT}
cache_A, cache_arg = cache
- Aval = isnothing(cache_A) ? A.val : cache_A
+ Aval = isnothing(cache_A) ? A.val : cache_A
if !isa(A, Const) && !isa(arg, Const)
A.dval .= zero(eltype(A.val))
$pb(A.dval, A.val, arg.val, arg.dval; kwargs...)
end
return (nothing, nothing, nothing)
end
- function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof($f)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- arg::Annotation{<:AbstractMatrix},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+ function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof($f)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ arg::Annotation{<:AbstractMatrix},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
ret = func.val(A.val, arg.val, alg.val; kwargs...)
if isa(arg, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const)
@@ -170,15 +180,16 @@ for (f, pb, pf) in ((qr_null!, qr_null_pullback!, qr_null_pushforward!),
end
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof(svd_compact!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
- ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
+function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof(svd_compact!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
+ ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
shadow = if EnzymeRules.needs_shadow(config)
svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval)
else
@@ -196,20 +207,21 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
end
# TODO
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof(svd_full!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
- ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
+function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof(svd_full!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
+ ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
shadow = if EnzymeRules.needs_shadow(config)
- svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval)
- else
- nothing
- end
+ svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval)
+ else
+ nothing
+ end
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated(ret, shadow)
elseif EnzymeRules.needs_shadow(config)
@@ -222,33 +234,36 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
end
for f in (:svd_compact!, :svd_full!)
@eval begin
- function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof($f)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+ function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof($f)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
# form cache if needed
- cache_USVᴴ = (EnzymeRules.overwritten(config)[3] && !(typeof(USVᴴ) <: Const)) ? copy(USVᴴ.val) : nothing
- cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(A) <: Const)) ? copy(A.val) : nothing
+ cache_USVᴴ = (EnzymeRules.overwritten(config)[3] && !(typeof(USVᴴ) <: Const)) ? copy(USVᴴ.val) : nothing
+ cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(A) <: Const)) ? copy(A.val) : nothing
func.val(A.val, USVᴴ.val, alg.val; kwargs...)
- primal = EnzymeRules.needs_primal(config) ? USVᴴ.val : nothing
+ primal = EnzymeRules.needs_primal(config) ? USVᴴ.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? USVᴴ.dval : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ))
end
- function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof($f)},
- dret::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...) where {RT}
+ function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof($f)},
+ dret::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...
+ ) where {RT}
cache_A, cache_USVᴴ = cache
USVᴴval = !isnothing(cache_USVᴴ) ? cache_USVᴴ : USVᴴ.val
- ∂USVᴴ = isa(USVᴴ, Const) ? nothing : USVᴴ.dval
+ ∂USVᴴ = isa(USVᴴ, Const) ? nothing : USVᴴ.dval
if !isa(A, Const) && !isa(USVᴴ, Const)
A.dval .= zero(eltype(A.dval))
MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴval, ∂USVᴴ; kwargs...)
@@ -261,26 +276,27 @@ for f in (:svd_compact!, :svd_full!)
end
end
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(svd_trunc!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
- ϵ::Annotation{Vector{T}},
- alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
- kwargs...,
- ) where {RT, T<:Real}
+function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(svd_trunc!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+ ϵ::Annotation{Vector{T}},
+ alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+ kwargs...,
+ ) where {RT, T <: Real}
# form cache if needed
- cache_A = copy(A.val)
+ cache_A = copy(A.val)
svd_compact!(A.val, USVᴴ.val, alg.val.alg)
- cache_USVᴴ = copy.(USVᴴ.val)
- USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc)
- ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind)
- primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing
+ cache_USVᴴ = copy.(USVᴴ.val)
+ USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc)
+ ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind)
+ primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing
shadow_USVᴴ = if !isa(A, Const) && !isa(USVᴴ, Const)
dU, dS, dVᴴ = USVᴴ.dval
- dStrunc = Diagonal(diagview(dS)[ind])
- dUtrunc = dU[:, ind]
+ dStrunc = Diagonal(diagview(dS)[ind])
+ dUtrunc = dU[:, ind]
dVᴴtrunc = dVᴴ[ind, :]
(dUtrunc, dStrunc, dVᴴtrunc)
else
@@ -289,17 +305,19 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
shadow = EnzymeRules.needs_shadow(config) ? (shadow_USVᴴ..., ϵ.dval) : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ, shadow_USVᴴ, ind))
end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(svd_trunc!)},
- dret::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
- ϵ::Annotation{Vector{T}},
- alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
- kwargs...) where {RT, T<:Real}
+function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(svd_trunc!)},
+ dret::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+ ϵ::Annotation{Vector{T}},
+ alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+ kwargs...
+ ) where {RT, T <: Real}
cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache
- U, S, Vᴴ = cache_USVᴴ
+ U, S, Vᴴ = cache_USVᴴ
dU, dS, dVᴴ = shadow_USVᴴ
if !isa(A, Const) && !isa(USVᴴ, Const)
A.dval .= zero(eltype(A.val))
@@ -314,21 +332,22 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
return (nothing, nothing, nothing, nothing)
end
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof(eigh_vals!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- D::Annotation{<:AbstractVector},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof(eigh_vals!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ D::Annotation{<:AbstractVector},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
Dmat, V = eigh_full(A.val; kwargs...)
if isa(D, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const)
- ∂K = inv(V) * A.dval * V
- ∂Kdiag = diag(∂K)
+ ∂K = inv(V) * A.dval * V
+ ∂Kdiag = diag(∂K)
D.dval .= real.(copy(∂Kdiag))
A.dval .= zero(eltype(A.val))
- shadow = D.dval
+ shadow = D.dval
elseif isa(A, Const) && !!isa(D, Union{Duplicated, DuplicatedNoNeed})
make_zero!(D.dval)
shadow = D.dval
@@ -340,20 +359,21 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
elseif EnzymeRules.needs_shadow(config)
return shadow
elseif EnzymeRules.needs_primal(config)
- return Dmat.diag
+ return Dmat.diag
else
return nothing
end
end
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof(eigh_full!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof(eigh_full!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
Dmat, V = func.val(A.val, DV.val; kwargs...)
if isa(A, Const) || all(iszero, A.dval)
make_zero!(DV.dval[1])
@@ -361,18 +381,18 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
make_zero!(A.dval)
shadow = (DV.dval[1], DV.dval[2])
else
- ∂K = inv(V) * A.dval * V
- ∂Kdiag = diagview(∂K)
- ∂Ddiag = diagview(DV.dval[1])
+ ∂K = inv(V) * A.dval * V
+ ∂Kdiag = diagview(∂K)
+ ∂Ddiag = diagview(DV.dval[1])
∂Ddiag .= real.(∂Kdiag)
- D = diagview(Dmat)
- dDD = transpose(D) .- D
- ∂K ./= dDD
+ D = diagview(Dmat)
+ dDD = transpose(D) .- D
+ ∂K ./= dDD
∂Kdiag .= zero(eltype(V))
mul!(DV.dval[2], V, ∂K, 1, 0)
- shadow = DV.dval[2]
+ shadow = DV.dval[2]
A.dval .= zero(eltype(A.val))
- shadow = (Diagonal(∂Ddiag), DV.dval[2])
+ shadow = (Diagonal(∂Ddiag), DV.dval[2])
end
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated((Dmat, V), shadow)
@@ -385,24 +405,25 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
end
end
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof(eig_vals!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- D::Annotation{<:AbstractVector},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof(eig_vals!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ D::Annotation{<:AbstractVector},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
Dval, V = eig_full(A.val, alg.val; kwargs...)
if isa(D, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const)
- ∂K = inv(V) * A.dval * V
- ∂Kdiag = diag(∂K)
+ ∂K = inv(V) * A.dval * V
+ ∂Kdiag = diag(∂K)
D.dval .= copy(∂Kdiag)
A.dval .= zero(eltype(A.val))
- shadow = D.dval
+ shadow = D.dval
elseif isa(A, Const) && !!isa(D, Union{Duplicated, DuplicatedNoNeed})
make_zero!(D.dval)
- shadow = D.dval
+ shadow = D.dval
end
eig_vals!(A.val, zeros(complex(eltype(A.val)), size(A.val, 1)))
D.val .= diagview(Dval)
@@ -411,30 +432,31 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
elseif EnzymeRules.needs_shadow(config)
return shadow
elseif EnzymeRules.needs_primal(config)
- return Dmat.diag
+ return Dmat.diag
else
return nothing
end
end
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eigh_trunc!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
- ϵ::Annotation{Vector{T}},
- alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
- kwargs...,
- ) where {RT, T}
+function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eigh_trunc!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+ ϵ::Annotation{Vector{T}},
+ alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+ kwargs...,
+ ) where {RT, T}
# form cache if needed
- cache_A = copy(A.val)
+ cache_A = copy(A.val)
MatrixAlgebraKit.eigh_full!(A.val, DV.val, alg.val.alg)
- cache_DV = copy.(DV.val)
- DV′, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV.val, alg.val.trunc)
- ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
- primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
+ cache_DV = copy.(DV.val)
+ DV′, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV.val, alg.val.trunc)
+ ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
+ primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
shadow_DV = if !isa(A, Const) && !isa(DV, Const)
- dD, dV = DV.dval
+ dD, dV = DV.dval
dDtrunc = Diagonal(diagview(dD)[ind])
dVtrunc = dV[:, ind]
(dDtrunc, dVtrunc)
@@ -444,17 +466,19 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., [zero(T)]) : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind))
end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eigh_trunc!)},
- ::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
- ϵ::Annotation{Vector{T}},
- alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
- kwargs...) where {RT, T}
+function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eigh_trunc!)},
+ ::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+ ϵ::Annotation{Vector{T}},
+ alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+ kwargs...
+ ) where {RT, T}
cache_A, cache_DV, cache_dDVtrunc, ind = cache
- D, V = cache_DV
+ D, V = cache_DV
dD, dV = cache_dDVtrunc
if !isa(A, Const) && !isa(DV, Const)
A.dval .= zero(eltype(A.val))
@@ -469,24 +493,25 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
return (nothing, nothing, nothing, nothing)
end
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eig_trunc!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
- ϵ::Annotation{Vector{T}},
- alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
- kwargs...,
- ) where {RT, T}
+function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eig_trunc!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+ ϵ::Annotation{Vector{T}},
+ alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+ kwargs...,
+ ) where {RT, T}
# form cache if needed
- cache_A = copy(A.val)
+ cache_A = copy(A.val)
eig_full!(A.val, DV.val, alg.val.alg)
- cache_DV = copy.(DV.val)
- DV′, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV.val, alg.val.trunc)
- ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
- primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
+ cache_DV = copy.(DV.val)
+ DV′, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV.val, alg.val.trunc)
+ ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
+ primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
shadow_DV = if !isa(A, Const) && !isa(DV, Const)
- dD, dV = DV.dval
+ dD, dV = DV.dval
dDtrunc = Diagonal(diagview(dD)[ind])
dVtrunc = dV[:, ind]
(dDtrunc, dVtrunc)
@@ -496,17 +521,19 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., [zero(T)]) : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind))
end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eig_trunc!)},
- ::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
- ϵ::Annotation{Vector{T}},
- alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
- kwargs...) where {RT, T}
+function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eig_trunc!)},
+ ::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+ ϵ::Annotation{Vector{T}},
+ alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+ kwargs...
+ ) where {RT, T}
cache_A, cache_DV, cache_dDVtrunc, ind = cache
- D, V = cache_DV
+ D, V = cache_DV
dD, dV = cache_dDVtrunc
if !isa(A, Const) && !isa(DV, Const)
A.dval .= zero(eltype(A.val))
@@ -521,47 +548,49 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
return (nothing, nothing, nothing, nothing)
end
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eigh_full!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eigh_full!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
# form cache if needed
cache_DV = nothing
- cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
+ cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
func.val(A.val, DV.val, alg.val; kwargs...)
primal = EnzymeRules.needs_primal(config) ? DV.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? DV.dval : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV))
end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eigh_full!)},
- ::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eigh_full!)},
+ ::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_A, cache_DV = cache
- DVval = !isnothing(cache_DV) ? cache_DV : DV.val
- Aval = !isnothing(cache_A) ? cache_A : A.val
- ∂DV = isa(DV, Const) ? nothing : DV.dval
+ DVval = !isnothing(cache_DV) ? cache_DV : DV.val
+ Aval = !isnothing(cache_A) ? cache_A : A.val
+ ∂DV = isa(DV, Const) ? nothing : DV.dval
if !isa(A, Const) && !isa(DV, Const)
- Dmat, V = DVval
+ Dmat, V = DVval
∂Dmat, ∂V = ∂DV
- A.dval .= zero(eltype(Aval))
+ A.dval .= zero(eltype(Aval))
MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, DVval, ∂DV; kwargs...)
A.dval .*= 2
diagview(A.dval) ./= 2
for i in 1:size(A.dval, 1), j in 1:size(A.dval, 2)
if i > j
- A.dval[i, j] = zero(eltype(A.dval))
+ A.dval[i, j] = zero(eltype(A.dval))
end
end
end
@@ -571,40 +600,42 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
return (nothing, nothing, nothing)
end
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eig_vals!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- D::Annotation{<:AbstractVector},
- alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eig_vals!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ D::Annotation{<:AbstractVector},
+ alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_D = nothing
- cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
+ cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
func.val(A.val, D.val, alg.val; kwargs...)
- primal = EnzymeRules.needs_primal(config) ? D.val : nothing
+ primal = EnzymeRules.needs_primal(config) ? D.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? D.dval : nothing
# form cache if needed
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D))
end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eig_vals!)},
- ::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- D::Annotation{<:AbstractVector},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eig_vals!)},
+ ::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ D::Annotation{<:AbstractVector},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_A, cache_D = cache
Dval = !isnothing(cache_D) ? cache_D : D.val
Aval = !isnothing(cache_A) ? cache_A : A.val
- ∂D = isa(D, Const) ? nothing : D.dval
+ ∂D = isa(D, Const) ? nothing : D.dval
if !isa(A, Const) && !isa(D, Const)
- _, V = eig_full(Aval, alg.val)
+ _, V = eig_full(Aval, alg.val)
A.dval .= zero(eltype(Aval))
- PΔV = V' \ Diagonal(D.dval)
+ PΔV = V' \ Diagonal(D.dval)
if eltype(A.dval) <: Real
ΔAc = PΔV * V'
A.dval .+= real.(ΔAc)
@@ -618,45 +649,47 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
return (nothing, nothing, nothing)
end
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eigh_vals!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- D::Annotation{<:AbstractVector},
- alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eigh_vals!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ D::Annotation{<:AbstractVector},
+ alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_D = nothing
- cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
+ cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
func.val(A.val, D.val, alg.val; kwargs...)
- primal = EnzymeRules.needs_primal(config) ? D.val : nothing
+ primal = EnzymeRules.needs_primal(config) ? D.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? D.dval : nothing
# form cache if needed
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D))
end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eigh_vals!)},
- ::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- D::Annotation{<:AbstractVector},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eigh_vals!)},
+ ::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ D::Annotation{<:AbstractVector},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_A, cache_D = cache
Dval = !isnothing(cache_D) ? cache_D : D.val
Aval = !isnothing(cache_A) ? cache_A : A.val
- ∂D = isa(D, Const) ? nothing : D.dval
+ ∂D = isa(D, Const) ? nothing : D.dval
if !isa(A, Const) && !isa(D, Const)
_, V = eigh_full(Aval, alg.val)
- A.dval .= zero(eltype(Aval))
+ A.dval .= zero(eltype(Aval))
mul!(A.dval, V * Diagonal(real(∂D)), V', 1, 0)
A.dval .*= 2
diagview(A.dval) ./= 2
for i in 1:size(A.dval, 1), j in 1:size(A.dval, 2)
if i > j
- A.dval[i, j] = zero(eltype(A.dval))
+ A.dval[i, j] = zero(eltype(A.dval))
end
end
end
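The `eigh_vals!` reverse rule above builds `V * Diagonal(real(∂D)) * V'`. It then doubles the result, halves the diagonal back, and zeroes the strict lower triangle. One plausible reading, which is my assumption and not something stated in the PR, is that this gives the gradient with respect to the upper triangle only, as when the eigenvalues are computed from `Symmetric(A, :U)`. The standalone sketch below reproduces those steps for real matrices and compares them with central finite differences. It uses only `LinearAlgebra`, not MatrixAlgebraKit or Enzyme, and both helper names are hypothetical.

```julia
using LinearAlgebra

# Hypothetical helper: cotangent of A for f(A) = dot(ΔD, eigvals(Symmetric(A, :U))),
# i.e. when only the upper triangle of a real matrix A is read.
function eigh_vals_pullback_upper(A::AbstractMatrix{<:Real}, ΔD::AbstractVector{<:Real})
    V = eigen(Symmetric(A, :U)).vectors
    G = V * Diagonal(ΔD) * V'      # gradient for a symmetric perturbation of A
    ΔA = 2 .* G                     # an upper off-diagonal entry also fills its mirror
    ΔA[diagind(ΔA)] .= diag(G)      # diagonal entries occur only once
    return triu(ΔA)                 # the strict lower triangle is never read
end

# Hypothetical helper: entry-wise central finite differences of the same f
function fd_gradient_upper(A::AbstractMatrix{<:Real}, ΔD::AbstractVector{<:Real}; h = 1e-6)
    f(B) = dot(ΔD, eigvals(Symmetric(B, :U)))
    g = zero(A)
    for idx in eachindex(A)
        Ap = copy(A); Ap[idx] += h
        Am = copy(A); Am[idx] -= h
        g[idx] = (f(Ap) - f(Am)) / (2h)
    end
    return g
end

A = randn(4, 4)    # need not be symmetric: only triu(A) is used
ΔD = randn(4)
ΔA = eigh_vals_pullback_upper(A, ΔD)
```

The finite-difference gradient is exactly zero below the diagonal, because `Symmetric(B, :U)` never reads those entries. This is why zeroing the lower triangle is consistent with that convention.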
diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
index 1797157..9bc5cae 100644
--- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
+++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
@@ -4,7 +4,7 @@ using Mooncake
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
using MatrixAlgebraKit
using MatrixAlgebraKit: inv_safe, diagview
-using MatrixAlgebraKit: svd_pushforward!
+using MatrixAlgebraKit: svd_pushforward!
using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pushforward!, lq_pushforward!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pushforward!, lq_null_pushforward!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pushforward!, eigh_pushforward!
@@ -12,74 +12,76 @@ using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_
using LinearAlgebra
# two-argument factorizations like LQ, QR, EIG
-for (f, pb, pf, adj) in ((qr_full!, qr_pullback!, qr_pushforward!, :dqr_adjoint),
- (qr_compact!, qr_pullback!, qr_pushforward!, :dqr_adjoint),
- (lq_full!, lq_pullback!, lq_pushforward!, :dlq_adjoint),
- (lq_compact!, lq_pullback!, lq_pushforward!, :dlq_adjoint),
- (eig_full!, eig_pullback!, eig_pushforward!, :deig_adjoint),
- (eigh_full!, eigh_pullback!, eigh_pushforward!, :deigh_adjoint),
- (left_polar!, left_polar_pullback!, left_polar_pushforward!, :dleft_polar_adjoint),
- (right_polar!, right_polar_pullback!, right_polar_pushforward!, :dright_polar_adjoint),
- )
+for (f, pb, pf, adj) in (
+ (qr_full!, qr_pullback!, qr_pushforward!, :dqr_adjoint),
+ (qr_compact!, qr_pullback!, qr_pushforward!, :dqr_adjoint),
+ (lq_full!, lq_pullback!, lq_pushforward!, :dlq_adjoint),
+ (lq_compact!, lq_pullback!, lq_pushforward!, :dlq_adjoint),
+ (eig_full!, eig_pullback!, eig_pushforward!, :deig_adjoint),
+ (eigh_full!, eigh_pullback!, eigh_pushforward!, :deigh_adjoint),
+ (left_polar!, left_polar_pullback!, left_polar_pushforward!, :dleft_polar_adjoint),
+ (right_polar!, right_polar_pullback!, right_polar_pushforward!, :dright_polar_adjoint),
+ )
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
- A, dA = arrayify(A_dA)
- dA .= zero(eltype(A))
- args = Mooncake.primal(args_dargs)
- dargs = Mooncake.tangent(args_dargs)
+ A, dA = arrayify(A_dA)
+ dA .= zero(eltype(A))
+ args = Mooncake.primal(args_dargs)
+ dargs = Mooncake.tangent(args_dargs)
arg1, darg1 = arrayify(args[1], dargs[1])
arg2, darg2 = arrayify(args[2], dargs[2])
function $adj(::Mooncake.NoRData)
dA = $pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...)
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
- args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
+ args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
darg1 .= zero(eltype(arg1))
darg2 .= zero(eltype(arg2))
return Mooncake.CoDual(args, dargs), $adj
end
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
- A, dA = arrayify(A_dA)
- args = Mooncake.primal(args_dargs)
- args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
- dargs = Mooncake.tangent(args_dargs)
- arg1, darg1 = arrayify(args[1], dargs[1])
- arg2, darg2 = arrayify(args[2], dargs[2])
+ A, dA = arrayify(A_dA)
+ args = Mooncake.primal(args_dargs)
+ args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
+ dargs = Mooncake.tangent(args_dargs)
+ arg1, darg1 = arrayify(args[1], dargs[1])
+ arg2, darg2 = arrayify(args[2], dargs[2])
darg1, darg2 = $pf(dA, A, (arg1, arg2), (darg1, darg2))
- dA .= zero(eltype(A))
+ dA .= zero(eltype(A))
return Mooncake.Dual(args, dargs)
end
end
end
-for (f, f_full, pb, pf, adj) in ((qr_null!, qr_full, qr_null_pullback!, qr_null_pushforward!, :dqr_null_adjoint),
- (lq_null!, lq_full, lq_null_pullback!, lq_null_pushforward!, :dlq_null_adjoint),
- )
+for (f, f_full, pb, pf, adj) in (
+ (qr_null!, qr_full, qr_null_pullback!, qr_null_pushforward!, :dqr_null_adjoint),
+ (lq_null!, lq_full, lq_null_pullback!, lq_null_pushforward!, :dlq_null_adjoint),
+ )
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, arg_darg::CoDual{<:AbstractMatrix}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
- A, dA = arrayify(A_dA)
- Ac = MatrixAlgebraKit.copy_input($f_full, A)
+ A, dA = arrayify(A_dA)
+ Ac = MatrixAlgebraKit.copy_input($f_full, A)
arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg))
- arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
+ arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
function $adj(::Mooncake.NoRData)
- dA .= zero(eltype(A))
+ dA .= zero(eltype(A))
$pb(dA, A, arg, darg; kwargs...)
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
- return arg_darg, $adj
+ return arg_darg, $adj
end
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(f_df::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, arg_darg::Dual{<:AbstractMatrix}, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
- A, dA = arrayify(A_dA)
- Ac = MatrixAlgebraKit.copy_input($f_full, A)
+ A, dA = arrayify(A_dA)
+ Ac = MatrixAlgebraKit.copy_input($f_full, A)
arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg))
- arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
+ arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
$pf(dA, A, arg, darg; kwargs...)
- dA .= zero(dA)
+ dA .= zero(dA)
return arg_darg
end
end
@@ -89,33 +91,33 @@ end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...)
# compute primal
- D_ = Mooncake.primal(D_dD)
- dD_ = Mooncake.tangent(D_dD)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
+ D_ = Mooncake.primal(D_dD)
+ dD_ = Mooncake.tangent(D_dD)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
D, dD = arrayify(D_, dD_)
nD, V = eig_full(A, alg_dalg.primal; kwargs...)
# update tangent
- tmp = V \ dA
- dD .= diagview(tmp * V)
- dA .= zero(eltype(dA))
+ tmp = V \ dA
+ dD .= diagview(tmp * V)
+ dA .= zero(eltype(dA))
return Mooncake.Dual(nD.diag, dD_)
end
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
# compute primal
- D_ = Mooncake.primal(D_dD)
- dD_ = Mooncake.tangent(D_dD)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
+ D_ = Mooncake.primal(D_dD)
+ dD_ = Mooncake.tangent(D_dD)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
D, dD = arrayify(D_, dD_)
- dA .= zero(eltype(dA))
- # update primal
- DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
- V = DV[2]
+ dA .= zero(eltype(dA))
+ # update primal
+ DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
+ V = DV[2]
dD .= zero(eltype(D))
function deig_vals_adjoint(::Mooncake.NoRData)
PΔV = V' \ Diagonal(dD)
@@ -163,30 +165,30 @@ end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...)
# compute primal
- D_ = Mooncake.primal(D_dD)
- dD_ = Mooncake.tangent(D_dD)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
+ D_ = Mooncake.primal(D_dD)
+ dD_ = Mooncake.tangent(D_dD)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
D, dD = arrayify(D_, dD_)
nD, V = eigh_full(A, alg_dalg.primal; kwargs...)
# update tangent
- tmp = inv(V) * dA * V
- dD .= real.(diagview(tmp))
- D .= nD.diag
- dA .= zero(eltype(dA))
+ tmp = inv(V) * dA * V
+ dD .= real.(diagview(tmp))
+ D .= nD.diag
+ dA .= zero(eltype(dA))
return D_dD
end
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
# compute primal
- D_ = Mooncake.primal(D_dD)
- dD_ = Mooncake.tangent(D_dD)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
+ D_ = Mooncake.primal(D_dD)
+ dD_ = Mooncake.tangent(D_dD)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
D, dD = arrayify(D_, dD_)
- DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
+ DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
function deigh_vals_adjoint(::Mooncake.NoRData)
mul!(dA, DV[2] * Diagonal(real(dD)), DV[2]', 1, 0)
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
@@ -199,60 +201,60 @@ for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal))
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual; kwargs...)
- A, dA = arrayify(A_dA)
- USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
- dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
- U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
- S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
+ A, dA = arrayify(A_dA)
+ USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
+ dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
+ U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
+ S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
- USVᴴ = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...)
+ USVᴴ = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...)
function dsvd_adjoint(::Mooncake.NoRData)
- dA .= zero(eltype(A))
+ dA .= zero(eltype(A))
minmn = min(size(A)...)
if size(U, 2) == size(Vᴴ, 1) == minmn # compact
- dA = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
+ dA = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
else # full
- vU = view(U, :, 1:minmn)
- vS = Diagonal(diagview(S)[1:minmn])
- vVᴴ = view(Vᴴ, 1:minmn, :)
- vdU = view(dU, :, 1:minmn)
- vdS = view(dS, 1:minmn, 1:minmn)
- vdVᴴ = view(dVᴴ, 1:minmn, :)
- dA = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
+ vU = view(U, :, 1:minmn)
+ vS = Diagonal(diagview(S)[1:minmn])
+ vVᴴ = view(Vᴴ, 1:minmn, :)
+ vdU = view(dU, :, 1:minmn)
+ vdS = view(dS, 1:minmn, 1:minmn)
+ vdVᴴ = view(dVᴴ, 1:minmn, :)
+ dA = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
end
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
- dU .= zero(dU)
- dS .= zero(dS)
+ dU .= zero(dU)
+ dS .= zero(dS)
dVᴴ .= zero(dVᴴ)
return Mooncake.CoDual(USVᴴ, dUSVᴴ), dsvd_adjoint
end
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{<:typeof($f)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual; kwargs...)
# compute primal
- USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
- dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
- A, dA = arrayify(A_, dA_)
+ USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
+ dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
+ A, dA = arrayify(A_, dA_)
$f(A, USVᴴ, alg_dalg.primal; kwargs...)
# update tangents
- U_, S_, Vᴴ_ = USVᴴ
+ U_, S_, Vᴴ_ = USVᴴ
dU_, dS_, dVᴴ_ = dUSVᴴ
- U, dU = arrayify(U_, dU_)
- S, dS = arrayify(S_, dS_)
- Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_)
- minmn = min(size(A)...)
+ U, dU = arrayify(U_, dU_)
+ S, dS = arrayify(S_, dS_)
+ Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_)
+ minmn = min(size(A)...)
if ($f == svd_compact!) # compact
svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ); kwargs...)
else # full
- vU = view(U, :, 1:minmn)
- vS = view(S, 1:minmn, 1:minmn)
- vVᴴ = view(Vᴴ, 1:minmn, :)
- vdU = view(dU, :, 1:minmn)
- vdS = view(dS, 1:minmn, 1:minmn)
- vdVᴴ = view(dVᴴ, 1:minmn, :)
+ vU = view(U, :, 1:minmn)
+ vS = view(S, 1:minmn, 1:minmn)
+ vVᴴ = view(Vᴴ, 1:minmn, :)
+ vdU = view(dU, :, 1:minmn)
+ vdS = view(dS, 1:minmn, 1:minmn)
+ vdVᴴ = view(dVᴴ, 1:minmn, :)
svd_pushforward!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ); kwargs...)
end
return USVᴴ_dUSVᴴ
@@ -263,15 +265,15 @@ end
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual; kwargs...)
# compute primal
- S_ = Mooncake.primal(S_dS)
- dS_ = Mooncake.tangent(S_dS)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
+ S_ = Mooncake.primal(S_dS)
+ dS_ = Mooncake.tangent(S_dS)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
# update tangent
- S, dS = arrayify(S_, dS_)
+ S, dS = arrayify(S_, dS_)
copyto!(dS, diag(real.(Vᴴ * dA' * U)))
copyto!(S, diagview(nS))
dA .= zero(eltype(dA))
@@ -281,17 +283,17 @@ end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual; kwargs...)
# compute primal
- S_ = Mooncake.primal(S_dS)
- dS_ = Mooncake.tangent(S_dS)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
+ S_ = Mooncake.primal(S_dS)
+ dS_ = Mooncake.tangent(S_dS)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
S, dS = arrayify(S_, dS_)
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
- S .= diagview(nS)
- dS .= zero(eltype(S))
+ S .= diagview(nS)
+ dS .= zero(eltype(S))
function dsvd_vals_adjoint(::Mooncake.NoRData)
- dA .= U * Diagonal(dS) * Vᴴ
+ dA .= U * Diagonal(dS) * Vᴴ
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
return S_dS, dsvd_vals_adjoint
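For `svd_vals!`, the forward rule above takes `dS` from the real part of the diagonal of `Vᴴ * dA' * U`. The reverse rule sets `dA = U * Diagonal(dS) * Vᴴ`. These two maps are adjoint to each other under the real Frobenius inner product: ⟨U Diag(ΔS) Vᴴ, ΔA⟩ = Σᵢ ΔSᵢ Re(uᵢᴴ ΔA vᵢ). The minimal sketch below checks the adjoint identity and also checks the forward rule against finite differences, for a real tall matrix. It uses plain `LinearAlgebra` without Mooncake, and the helper names are hypothetical.

```julia
using LinearAlgebra

# Hypothetical helper: pushforward of the singular values, dσᵢ = Re(uᵢᴴ ΔA vᵢ)
svd_vals_pushforward(U, Vᴴ, ΔA) = real.(diag(U' * ΔA * Vᴴ'))

# Hypothetical helper: pullback of the singular values, ΔA = U * Diagonal(ΔS) * Vᴴ
svd_vals_pullback(U, Vᴴ, ΔS) = U * Diagonal(ΔS) * Vᴴ

A = randn(5, 3)
E = randn(5, 3)     # tangent direction for A
ΔS = randn(3)       # cotangent for the singular values
F = svd(A)          # thin SVD: F.U is 5×3, F.Vt is 3×3
dS = svd_vals_pushforward(F.U, F.Vt, E)

# central finite differences of svdvals along E
h = 1e-6
dS_fd = (svdvals(A + h * E) - svdvals(A - h * E)) / (2h)

# adjoint identity: <pullback(ΔS), E> == <ΔS, pushforward(E)>
lhs = dot(svd_vals_pullback(F.U, F.Vt, ΔS), E)
rhs = dot(ΔS, dS)
```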
diff --git a/src/common/view.jl b/src/common/view.jl
index 0bc7b9e..c8ae1aa 100644
--- a/src/common/view.jl
+++ b/src/common/view.jl
@@ -1,5 +1,5 @@
# diagind: provided by LinearAlgebra.jl
-diagview(D::Diagonal) = D.diag
+diagview(D::Diagonal) = D.diag
diagview(D::AbstractMatrix) = view(D, diagind(D))
# triangularind
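`diagview` returns either the `Diagonal`'s own storage or a view into the matrix. Writes through it therefore mutate the original. The pushforwards rely on this in calls such as `fill!(diagview(F), zero(eltype(F)))`. The small illustration below copies the two definitions under a local, hypothetical name so that it does not need MatrixAlgebraKit.

```julia
using LinearAlgebra

# Local copies of the two diagview methods shown in the diff
mydiagview(D::Diagonal) = D.diag
mydiagview(D::AbstractMatrix) = view(D, diagind(D))

M = [1.0 2.0; 3.0 4.0]
fill!(mydiagview(M), 0.0)    # writes through the view into M

D = Diagonal([1.0, 2.0])
d = mydiagview(D)            # the same vector as D.diag, not a copy
d .*= 10                     # so this mutates D in place
```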
diff --git a/src/pushforwards/eig.jl b/src/pushforwards/eig.jl
index 19a43cb..36506f3 100644
--- a/src/pushforwards/eig.jl
+++ b/src/pushforwards/eig.jl
@@ -1,11 +1,11 @@
function eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...)
- D, V = DV
- ΔD, ΔV = ΔDV
- iVΔAV = inv(V) * ΔA * V
+ D, V = DV
+ ΔD, ΔV = ΔDV
+ iVΔAV = inv(V) * ΔA * V
diagview(ΔD) .= diagview(iVΔAV)
- F = 1 ./ (transpose(diagview(D)) .- diagview(D))
+ F = 1 ./ (transpose(diagview(D)) .- diagview(D))
fill!(diagview(F), zero(eltype(F)))
- K̇ = F .* iVΔAV
+ K̇ = F .* iVΔAV
mul!(ΔV, V, K̇, 1, 0)
zero!(ΔA)
return ΔDV
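`eig_pushforward!` follows from differentiating A V = V D. Writing ΔV = V K̇ gives V⁻¹ ΔA V = ΔD + K̇ D − D K̇. So ΔD is the diagonal of V⁻¹ ΔA V, and off the diagonal K̇ᵢⱼ = (V⁻¹ ΔA V)ᵢⱼ / (λⱼ − λᵢ). This is exactly `F .* iVΔAV` with `F = 1 ./ (transpose(λ) .- λ)`. The diagonal of K̇ is a gauge choice and is set to zero. The Julia sketch below mirrors those steps and checks the linearized eigen-equation residual. It uses a hypothetical helper name and plain `LinearAlgebra`, and takes complex random input so the eigenvalues are almost surely distinct.

```julia
using LinearAlgebra

# Hypothetical helper mirroring the steps of eig_pushforward!
function eig_pushforward_sketch(A::AbstractMatrix, ΔA::AbstractMatrix)
    λ, V = eigen(A)
    iVΔAV = V \ (ΔA * V)             # V⁻¹ ΔA V
    ΔD = diag(iVΔAV)                 # first-order change of the eigenvalues
    F = 1 ./ (transpose(λ) .- λ)     # F[i, j] = 1 / (λ[j] - λ[i])
    F[diagind(F)] .= 0               # gauge choice: no change along each vᵢ
    ΔV = V * (F .* iVΔAV)            # ΔV = V * K̇
    return λ, V, ΔD, ΔV
end

A = randn(ComplexF64, 4, 4)
ΔA = randn(ComplexF64, 4, 4)
λ, V, ΔD, ΔV = eig_pushforward_sketch(A, ΔA)

# Differentiating A * V = V * Diagonal(λ) gives this residual, which should vanish
res = ΔA * V + A * ΔV - ΔV * Diagonal(λ) - V * Diagonal(ΔD)
```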
diff --git a/src/pushforwards/eigh.jl b/src/pushforwards/eigh.jl
index 69685b1..5fdfdca 100644
--- a/src/pushforwards/eigh.jl
+++ b/src/pushforwards/eigh.jl
@@ -1,16 +1,16 @@
function eigh_pushforward!(dA, A, DV, dDV; kwargs...)
- D, V = DV
- dD, dV = dDV
- tmpV = V \ dA
- ∂K = tmpV * V
- ∂Kdiag = diag(∂K)
- dD.diag .= real.(∂Kdiag)
- dDD = transpose(diagview(D)) .- diagview(D)
- F = one(eltype(dDD)) ./ dDD
+ D, V = DV
+ dD, dV = dDV
+ tmpV = V \ dA
+ ∂K = tmpV * V
+ ∂Kdiag = diag(∂K)
+ dD.diag .= real.(∂Kdiag)
+ dDD = transpose(diagview(D)) .- diagview(D)
+ F = one(eltype(dDD)) ./ dDD
diagview(F) .= zero(eltype(F))
- ∂K .*= F
- ∂V = mul!(tmpV, V, ∂K)
+ ∂K .*= F
+ ∂V = mul!(tmpV, V, ∂K)
copyto!(dV, ∂V)
- dA .= zero(eltype(A))
+ dA .= zero(eltype(A))
return (dD, dV)
end
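In the Hermitian case V is unitary, so `V \ dA` equals `V' * dA`. F is real and antisymmetric, so the matrix `∂K .*= F` built in `eigh_pushforward!` is anti-Hermitian whenever dA is Hermitian. It follows that V + dV stays unitary to first order: V' dV + dV' V = 0. The short sketch below checks this property. It again uses a hypothetical helper and plain `LinearAlgebra`.

```julia
using LinearAlgebra

# Hypothetical helper mirroring the steps of eigh_pushforward! for Hermitian A
function eigh_pushforward_sketch(A::Hermitian, ΔA::AbstractMatrix)
    λ, V = eigen(A)                  # λ real and sorted, V unitary
    ∂K = V' * ΔA * V                 # same as (V \ ΔA) * V since V is unitary
    ΔD = real.(diag(∂K))
    F = 1 ./ (transpose(λ) .- λ)     # real, antisymmetric off the diagonal
    F[diagind(F)] .= 0
    ΔV = V * (F .* ∂K)
    return λ, V, ΔD, ΔV
end

n = 4
X = randn(ComplexF64, n, n)
A = Hermitian(X + X')
Y = randn(ComplexF64, n, n)
ΔA = Y + Y'                          # Hermitian tangent direction
λ, V, ΔD, ΔV = eigh_pushforward_sketch(A, ΔA)

K = V' * ΔV                          # should be anti-Hermitian
```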
diff --git a/src/pushforwards/lq.jl b/src/pushforwards/lq.jl
index 2d390a5..ed5a72d 100644
--- a/src/pushforwards/lq.jl
+++ b/src/pushforwards/lq.jl
@@ -62,7 +62,7 @@
end=#
function lq_pushforward!(dA, A, LQ, dLQ; kwargs...)
- qr_pushforward!(dA, A, (adjoint(LQ[2]), adjoint(LQ[1])), (adjoint(dLQ[2]), adjoint(dLQ[1])); kwargs...)
+ return qr_pushforward!(dA, A, (adjoint(LQ[2]), adjoint(LQ[1])), (adjoint(dLQ[2]), adjoint(dLQ[1])); kwargs...)
end
-function lq_null_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) end
+function lq_null_pushforward!(dA, A, LQ, dLQ; tol::Real = MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real = tol, gauge_atol::Real = tol) end
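`lq_pushforward!` reuses `qr_pushforward!` through the identity A = L Q ⟺ Aᴴ = Qᴴ Lᴴ. Qᴴ has orthonormal columns and Lᴴ is upper triangular, so (Qᴴ, Lᴴ) is a QR decomposition of Aᴴ, and its tangent is ΔAᴴ. The snippet below is a quick check of that identity using `LinearAlgebra`'s own `lq` and `qr`. It uses a square input to sidestep thin/full shape conventions.

```julia
using LinearAlgebra

A = randn(4, 4)
F = lq(A)
L = F.L            # lower triangular
Q = Matrix(F.Q)    # orthogonal for square A

# Aᴴ = Qᴴ * Lᴴ with Qᴴ orthogonal and Lᴴ upper triangular: a QR decomposition of Aᴴ
Rt = Matrix(L')
G = qr(Matrix(A'))
```

Since a QR decomposition of a full-rank matrix is unique up to the signs of diag(R), `G.R` and `Rt` agree up to those signs.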
diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl
index 2001c41..78ab79d 100644
--- a/src/pushforwards/polar.jl
+++ b/src/pushforwards/polar.jl
@@ -1,23 +1,23 @@
function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...)
- W, P = WP
+ W, P = WP
ΔW, ΔP = ΔWP
- aWdA = adjoint(W) * ΔA
- K̇ = sylvester(P, P, -(aWdA - adjoint(aWdA)))
- L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W*adjoint(W))*ΔA*inv(P)
- ΔW .= W * K̇ + L̇
- ΔP .= aWdA - K̇*P
+ aWdA = adjoint(W) * ΔA
+ K̇ = sylvester(P, P, -(aWdA - adjoint(aWdA)))
+ L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W * adjoint(W)) * ΔA * inv(P)
+ ΔW .= W * K̇ + L̇
+ ΔP .= aWdA - K̇ * P
MatrixAlgebraKit.zero!(ΔA)
return (ΔW, ΔP)
end
function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)
- P, Wᴴ = PWᴴ
+ P, Wᴴ = PWᴴ
ΔP, ΔWᴴ = ΔPWᴴ
- dAW = ΔA * adjoint(Wᴴ)
- K̇ = sylvester(P, P, -(dAW - adjoint(dAW)))
- L̇ = inv(P)*ΔA*(Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ)
- ΔWᴴ .= K̇ * Wᴴ + L̇
- ΔP .= dAW - P * K̇
+ dAW = ΔA * adjoint(Wᴴ)
+ K̇ = sylvester(P, P, -(dAW - adjoint(dAW)))
+ L̇ = inv(P) * ΔA * (Diagonal(ones(el...
Has been superseded.
This adds CPU-based reverse rules for all the factorizations in MatrixAlgebraKit (NB below), using Mooncake or Enzyme. I tried my best to reuse the existing pullbacks.
Some notes:
eigh_trunc!, eig_trunc!, or svd_trunc! rules due to "Can't use test_reverse with MixedDuplicated" (EnzymeAD/Enzyme.jl#2677) and "MixedDuplicated fails in Enzyme.autodiff" (EnzymeAD/Enzyme.jl#2678). This is because these functions return a mix of mutable and immutable objects. I'll work on a "bypass" signature for each that takes a reference to epsilon, which should let us test them directly.