Skip to content

Commit f1385c7

Browse files
author
Katharine Hyatt
committed
Small fixes
1 parent b79f25c commit f1385c7

1 file changed

Lines changed: 10 additions & 12 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -379,12 +379,11 @@ for (f!, f_full!, pb!) in (
379379
ret = TD == Nothing ? diagview(nD) : copy!(D.val, diagview(nD))
380380
cache_D = (D.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing
381381
primal = EnzymeRules.needs_primal(config) ? ret : nothing
382-
dret = if EnzymeRules.needs_shadow(config)
383-
TD == Nothing || isa(D, Const) ? zero(ret) : D.dval
384-
else
385-
nothing
386-
end
387-
return EnzymeRules.AugmentedReturn(primal, dret, (cache_D, dret, V))
382+
# on 1.10, Enzyme can get confused about whether it needs the shadow
383+
# create dret no matter what to account for this
384+
dret = TD == Nothing || isa(D, Const) ? zero(ret) : D.dval
385+
shadow = EnzymeRules.needs_shadow(config) ? dret : nothing
386+
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_D, dret, V))
388387
end
389388
function EnzymeRules.reverse(
390389
config::EnzymeRules.RevConfigWidth{1},
@@ -426,12 +425,11 @@ function EnzymeRules.augmented_primal(
426425
ret = TS == Nothing ? diagview(nS) : copy!(S.val, diagview(nS))
427426
cache_S = (S.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing
428427
primal = EnzymeRules.needs_primal(config) ? ret : nothing
429-
dret = if EnzymeRules.needs_shadow(config)
430-
TS == Nothing || isa(S, Const):zero(ret):S.dval
431-
else
432-
nothing
433-
end
434-
return EnzymeRules.AugmentedReturn(primal, dret, (cache_S, dret, U, Vᴴ))
428+
# on 1.10, Enzyme can get confused about whether it needs the shadow
429+
# create dret no matter what to account for this
430+
dret = TS == Nothing || isa(S, Const) ? zero(ret) : S.dval
431+
shadow = EnzymeRules.needs_shadow(config) ? dret : nothing
432+
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_S, dret, U, Vᴴ))
435433
end
436434
function EnzymeRules.reverse(
437435
config::EnzymeRules.RevConfigWidth{1},

0 commit comments

Comments
 (0)