Skip to content

Commit c6b3062

Browse files
kshyattJutho
andauthored
Update ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
Co-authored-by: Jutho <Jutho@users.noreply.github.com>
1 parent ff34d9e commit c6b3062

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,12 @@ for (f, trunc_f, full_f, pb) in (
324324
cache_DV = (DV.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing
325325
DV′, ind = truncate($trunc_f, ret, alg.val.trunc)
326326
primal = EnzymeRules.needs_primal(config) ? DV′ : nothing
327-
dret = (Diagonal(zero(diagview(DV′[1]))), zero(DV′[2]))
328-
shadow = EnzymeRules.needs_shadow(config) ? dret : nothing
329-
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_DV, dret, ind))
327+
dret = if EnzymeRules.needs_shadow(config)
328+
(Diagonal(zero(diagview(DV′[1]))), zero(DV′[2]))
329+
else
330+
nothing
331+
end
332+
return EnzymeRules.AugmentedReturn(primal, dret, (cache_DV, dret, ind))
330333
end
331334
@eval function EnzymeRules.reverse(
332335
config::EnzymeRules.RevConfigWidth{1},

0 commit comments

Comments
 (0)