Skip to content

Commit 08eaf37

Browse files
authored
Fix check of whether to copy arg for Enzyme (#239)
* Force usage of working Enzyme * Fix check for cache_arg
1 parent f93de2c commit 08eaf37

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Aqua = "0.6, 0.7, 0.8"
3131
CUDA = "6"
3232
ChainRulesCore = "1"
3333
ChainRulesTestUtils = "1"
34-
Enzyme = "0.13.131"
34+
Enzyme = "0.13.148"
3535
EnzymeTestUtils = "0.2.5"
3636
GenericLinearAlgebra = "0.3.19, 0.4"
3737
GenericSchur = "0.5.6"

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,13 @@ for (f, pb) in (
6767
# so we do not need to cache it. This may change if future pullbacks
6868
# depend directly on A!
6969
ret = func.val(A.val, arg.val, alg.val)
70-
# if arg.val == ret, the annotation must be Duplicated or DuplicatedNoNeed
70+
# if arg.val === ret, the annotation must be Duplicated or DuplicatedNoNeed
7171
# if arg isa Const, ret may still be modified further down the call graph so we should
7272
# copy it to protect ourselves
7373
A_is_arg1 = !isa(A, Const) && A.val === arg.val[1]
7474
A_is_arg2 = !isa(A, Const) && A.val === arg.val[2]
7575
A_is_arg = A_is_arg1 || A_is_arg2
76-
cache_arg = (arg.val !== ret && !A_is_arg) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing
76+
cache_arg = arg.val !== ret || A_is_arg || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing
7777
dret = if EnzymeRules.needs_shadow(config) && ((TA == Nothing && TB == Nothing) || isa(arg, Const))
7878
make_zero.(ret)
7979
elseif EnzymeRules.needs_shadow(config)

0 commit comments

Comments
 (0)