@@ -43,37 +43,6 @@ using LinearAlgebra
4343# other provided input variable.
4444# --------------------------------------------
4545
46- # this rule is necessary for now as without it,
47- # a segfault occurs both on 1.10 and 1.12 -- likely
48- # a deeper internal bug
49- function EnzymeRules. augmented_primal (
50- config:: EnzymeRules.RevConfigWidth{1} ,
51- func:: Const{typeof(copy_input)} ,
52- :: Type{RT} ,
53- f:: Annotation ,
54- A:: Annotation
55- ) where {RT}
56- ret = func. val (f. val, A. val)
57- primal = EnzymeRules. needs_primal (config) ? ret : nothing
58- shadow = EnzymeRules. needs_shadow (config) ? zero (A. dval) : nothing
59- return EnzymeRules. AugmentedReturn (primal, shadow, shadow)
60- end
61-
62- function EnzymeRules. reverse (
63- config:: EnzymeRules.RevConfigWidth{1} ,
64- func:: Const{typeof(copy_input)} ,
65- :: Type{RT} ,
66- cache,
67- f:: Annotation ,
68- A:: Annotation
69- ) where {RT}
70- copy_shadow = cache
71- if ! isa (A, Const) && ! isnothing (copy_shadow)
72- A. dval .+ = copy_shadow
73- end
74- return (nothing , nothing )
75- end
76-
7746# two-argument factorizations like LQ, QR, EIG
7847for (f, pb) in (
7948 (qr_full!, qr_pullback!),
@@ -101,9 +70,11 @@ for (f, pb) in (
10170 # if arg.val == ret, the annotation must be Duplicated or DuplicatedNoNeed
10271 # if arg isa Const, ret may still be modified further down the call graph so we should
10372 # copy it to protect ourselves
104- cache_arg = (arg. val != = ret) || EnzymeRules. overwritten (config)[3 ] ? copy .(ret) : nothing
105- dret = if EnzymeRules. needs_shadow (config)
106- (TA == Nothing && TB == Nothing) || isa (arg, Const) ? zero .(ret) : arg. dval
73+ cache_arg = (arg. val != = ret && arg. val[1 ] != = A. val) || EnzymeRules. overwritten (config)[3 ] ? copy .(ret) : nothing
74+ dret = if EnzymeRules. needs_shadow (config) && ((TA == Nothing && TB == Nothing) || isa (arg, Const))
75+ make_zero .(ret)
76+ elseif EnzymeRules. needs_shadow (config)
77+ arg. dval
10778 else
10879 nothing
10980 end
@@ -125,11 +96,24 @@ for (f, pb) in (
12596 # use A (so that whoever does this is forced to handle caching A
12697 # appropriately here)
12798 Aval = nothing
99+ A_is_arg = ! isa (A, Const) && A. dval === arg. dval[1 ]
128100 argval = something (cache_arg, arg. val)
129101 if ! isa (A, Const)
130- $ pb (A. dval, Aval, argval, darg)
102+ if A_is_arg
103+ ΔA = make_zero (A. dval)
104+ $ pb (ΔA, Aval, argval, darg)
105+ A. dval .= ΔA
106+ else
107+ $ pb (A. dval, Aval, argval, darg)
108+ end
109+ end
110+ if ! isa (arg, Const)
111+ if ! A_is_arg
112+ make_zero! (arg. dval)
113+ else
114+ make_zero! (arg. dval[2 ])
115+ end
131116 end
132- ! isa (arg, Const) && make_zero! (arg. dval)
133117 return (nothing , nothing , nothing )
134118 end
135119 end
@@ -356,7 +340,13 @@ for (f, trunc_f, full_f, pb) in (
356340 if ! isa (A, Const)
357341 $ pb (A. dval, Aval, DVval, dDVtrunc, ind)
358342 end
359- ! isa (DV, Const) && make_zero! (DV. dval)
343+ if ! isa (DV, Const)
344+ if ! (A. dval === DV. dval[1 ])
345+ make_zero! (DV. dval)
346+ else
347+ make_zero! (DV. dval[2 ])
348+ end
349+ end
360350 return (nothing , nothing , nothing )
361351 end
362352end
@@ -392,21 +382,30 @@ for (f!, f_full!, pb!) in (
392382 func:: Const{typeof($f!)} ,
393383 :: Type{RT} ,
394384 cache,
395- A:: Annotation ,
385+ A:: Annotation{TA} ,
396386 D:: Annotation ,
397387 alg:: Const{<:MatrixAlgebraKit.AbstractAlgorithm} ,
398- ) where {RT}
388+ ) where {RT, TA }
399389 cache_D, dD, V = cache
400390 Dval = something (cache_D, D. val)
401391 # A is NOT used in the pullback, so we assign Aval = nothing
402392 # to trigger an error in case the pullback is modified to directly
403393 # use A (so that whoever does this is forced to handle caching A
404394 # appropriately here)
405395 Aval = nothing
396+ A_is_arg = ! isa (A, Const) && TA <: Diagonal && diagview (A. dval) === D. dval
406397 if ! isa (A, Const)
407- $ pb! (A. dval, Aval, (Diagonal (Dval), V), dD)
398+ if A_is_arg
399+ ΔA = make_zero (A. dval)
400+ $ pb! (ΔA, Aval, (Diagonal (Dval), V), dD)
401+ A. dval .= ΔA
402+ else
403+ $ pb! (A. dval, Aval, (Diagonal (Dval), V), dD)
404+ end
405+ end
406+ if ! isa (D, Const) && ! A_is_arg
407+ make_zero! (D. dval)
408408 end
409- ! isa (D, Const) && make_zero! (D. dval)
410409 return (nothing , nothing , nothing )
411410 end
412411 end
@@ -438,10 +437,10 @@ function EnzymeRules.reverse(
438437 func:: Const{typeof(svd_vals!)} ,
439438 :: Type{RT} ,
440439 cache,
441- A:: Annotation ,
440+ A:: Annotation{TA} ,
442441 S:: Annotation ,
443442 alg:: Const{<:MatrixAlgebraKit.AbstractAlgorithm} ,
444- ) where {RT}
443+ ) where {RT, TA }
445444 cache_S, dS, U, Vᴴ = cache
446445 # A is NOT used in the pullback, so we assign Aval = nothing
447446 # to trigger an error in case the pullback is modified to directly
@@ -452,7 +451,9 @@ function EnzymeRules.reverse(
452451 if ! isa (A, Const)
453452 svd_vals_pullback! (A. dval, Aval, (U, Diagonal (Sval), Vᴴ), dS)
454453 end
455- ! isa (S, Const) && make_zero! (S. dval)
454+ if ! isa (S, Const) && ! (TA <: Diagonal && (diagview (A. dval) === S. dval))
455+ make_zero! (S. dval)
456+ end
456457 return (nothing , nothing , nothing )
457458end
458459
0 commit comments