Skip to content

Commit 9960d29

Browse files
kshyattJutho
andauthored
Update ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
Co-authored-by: Jutho <Jutho@users.noreply.github.com>
1 parent c6b3062 commit 9960d29

1 file changed

Lines changed: 7 additions & 10 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -372,18 +372,15 @@ for (f!, f_full!, pb!) in (
372372
# so we do not need to cache it. This may change if future pullbacks
373373
# depend directly on A!
374374
nD, V = $f_full!(A.val, alg.val)
375-
if TD == Nothing || isa(D, Const)
376-
ret = diagview(nD)
377-
dret = zero(ret)
378-
else
379-
ret = D.val
380-
copy!(ret, diagview(nD))
381-
dret = D.dval
382-
end
375+
ret = TD == Nothing ? diagview(nD) : copy!(D.val, diagview(nD))
383376
cache_D = (D.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing
384377
primal = EnzymeRules.needs_primal(config) ? ret : nothing
385-
shadow = EnzymeRules.needs_shadow(config) ? dret : nothing
386-
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_D, dret, V))
378+
dret = if EnzymeRules.needs_shadow(config)
379+
TD == Nothing || isa(D, Const) ? zero(ret) : D.dval
380+
else
381+
nothing
382+
end
383+
return EnzymeRules.AugmentedReturn(primal, dret, (cache_D, dret, V))
387384
end
388385
function EnzymeRules.reverse(
389386
config::EnzymeRules.RevConfigWidth{1},

0 commit comments

Comments
 (0)