Skip to content

Commit c544809

Browse files
author
Katharine Hyatt
committed
Small fixes
1 parent b79f25c commit c544809

1 file changed

Lines changed: 11 additions & 3 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ for (f!, f_full!, pb!) in (
379379
ret = TD == Nothing ? diagview(nD) : copy!(D.val, diagview(nD))
380380
cache_D = (D.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing
381381
primal = EnzymeRules.needs_primal(config) ? ret : nothing
382+
# on 1.10, Enzyme can get confused about whether it needs the shadow
382383
dret = if EnzymeRules.needs_shadow(config)
383384
TD == Nothing || isa(D, Const) ? zero(ret) : D.dval
384385
else
@@ -397,13 +398,16 @@ for (f!, f_full!, pb!) in (
397398
) where {RT}
398399
cache_D, dD, V = cache
399400
Dval = something(cache_D, D.val)
401+
# on 1.10, Enzyme can get confused about whether it needs the shadow
402+
# replace the dret with the arg.dval in the case it's nothing
403+
dDval = something(dD, D.dval)
400404
# A is NOT used in the pullback, so we assign Aval = nothing
401405
# to trigger an error in case the pullback is modified to directly
402406
# use A (so that whoever does this is forced to handle caching A
403407
# appropriately here)
404408
Aval = nothing
405409
if !isa(A, Const)
406-
$pb!(A.dval, Aval, (Diagonal(Dval), V), dD)
410+
$pb!(A.dval, Aval, (Diagonal(Dval), V), dDval)
407411
end
408412
!isa(D, Const) && make_zero!(D.dval)
409413
return (nothing, nothing, nothing)
@@ -426,8 +430,9 @@ function EnzymeRules.augmented_primal(
426430
ret = TS == Nothing ? diagview(nS) : copy!(S.val, diagview(nS))
427431
cache_S = (S.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing
428432
primal = EnzymeRules.needs_primal(config) ? ret : nothing
433+
# on 1.10, Enzyme can get confused about whether it needs the shadow
429434
dret = if EnzymeRules.needs_shadow(config)
430-
TS == Nothing || isa(S, Const):zero(ret):S.dval
435+
TS == Nothing || isa(S, Const) ? zero(ret) : S.dval
431436
else
432437
nothing
433438
end
@@ -449,8 +454,11 @@ function EnzymeRules.reverse(
449454
# appropriately here)
450455
Aval = nothing
451456
Sval = something(cache_S, S.val)
457+
# on 1.10, Enzyme can get confused about whether it needs the shadow
458+
# replace the dret with the arg.dval in the case it's nothing
459+
dSval = something(dS, S.dval)
452460
if !isa(A, Const)
453-
svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS)
461+
svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dSval)
454462
end
455463
!isa(S, Const) && make_zero!(S.dval)
456464
return (nothing, nothing, nothing)

0 commit comments

Comments
 (0)