Skip to content

Commit 40d17f3

Browse files
kshyattJutho
andauthored
Update ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
Co-authored-by: Jutho <Jutho@users.noreply.github.com>
1 parent 9960d29 commit 40d17f3

1 file changed

Lines changed: 9 additions & 1 deletion

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,15 @@ function EnzymeRules.augmented_primal(
430430
cache_S = (S.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing
431431
primal = EnzymeRules.needs_primal(config) ? ret : nothing
432432
shadow = EnzymeRules.needs_shadow(config) ? dret : nothing
433-
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_S, dret, U, Vᴴ))
433+
ret = TS == Nothing ? diagview(nS) : copy!(S.val, diagview(nS))
434+
cache_S = (S.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing
435+
primal = EnzymeRules.needs_primal(config) ? ret : nothing
436+
dret = if EnzymeRules.needs_shadow(config)
437+
TS == Nothing || isa(S, Const) : zero(ret) : S.dval
438+
else
439+
nothing
440+
end
441+
return EnzymeRules.AugmentedReturn(primal, dret, (cache_S, dret, U, Vᴴ))
434442
end
435443
function EnzymeRules.reverse(
436444
config::EnzymeRules.RevConfigWidth{1},

0 commit comments

Comments
 (0)