@@ -43,37 +43,6 @@ using LinearAlgebra
4343# other provided input variable.
4444# --------------------------------------------
4545
46- # this rule is necessary for now as without it,
47- # a segfault occurs both on 1.10 and 1.12 -- likely
48- # a deeper internal bug
49- function EnzymeRules. augmented_primal (
50- config:: EnzymeRules.RevConfigWidth{1} ,
51- func:: Const{typeof(copy_input)} ,
52- :: Type{RT} ,
53- f:: Annotation ,
54- A:: Annotation
55- ) where {RT}
56- ret = func. val (f. val, A. val)
57- primal = EnzymeRules. needs_primal (config) ? ret : nothing
58- shadow = EnzymeRules. needs_shadow (config) ? zero (A. dval) : nothing
59- return EnzymeRules. AugmentedReturn (primal, shadow, shadow)
60- end
61-
62- function EnzymeRules. reverse (
63- config:: EnzymeRules.RevConfigWidth{1} ,
64- func:: Const{typeof(copy_input)} ,
65- :: Type{RT} ,
66- cache,
67- f:: Annotation ,
68- A:: Annotation
69- ) where {RT}
70- copy_shadow = cache
71- if ! isa (A, Const) && ! isnothing (copy_shadow)
72- A. dval .+ = copy_shadow
73- end
74- return (nothing , nothing )
75- end
76-
7746# two-argument factorizations like LQ, QR, EIG
7847for (f, pb) in (
7948 (qr_full!, qr_pullback!),
@@ -101,9 +70,14 @@ for (f, pb) in (
10170 # if arg.val == ret, the annotation must be Duplicated or DuplicatedNoNeed
10271 # if arg isa Const, ret may still be modified further down the call graph so we should
10372 # copy it to protect ourselves
104- cache_arg = (arg. val != = ret) || EnzymeRules. overwritten (config)[3 ] ? copy .(ret) : nothing
105- dret = if EnzymeRules. needs_shadow (config)
106- (TA == Nothing && TB == Nothing) || isa (arg, Const) ? zero .(ret) : arg. dval
73+ A_is_arg1 = ! isa (A, Const) && A. val === arg. val[1 ]
74+ A_is_arg2 = ! isa (A, Const) && A. val === arg. val[2 ]
75+ A_is_arg = A_is_arg1 || A_is_arg2
76+ cache_arg = (arg. val != = ret && ! A_is_arg) || EnzymeRules. overwritten (config)[3 ] ? copy .(ret) : nothing
77+ dret = if EnzymeRules. needs_shadow (config) && ((TA == Nothing && TB == Nothing) || isa (arg, Const))
78+ make_zero .(ret)
79+ elseif EnzymeRules. needs_shadow (config)
80+ arg. dval
10781 else
10882 nothing
10983 end
@@ -125,11 +99,19 @@ for (f, pb) in (
12599 # use A (so that whoever does this is forced to handle caching A
126100 # appropriately here)
127101 Aval = nothing
102+ A_is_arg1 = ! isa (A, Const) && A. dval === arg. dval[1 ]
103+ A_is_arg2 = ! isa (A, Const) && A. dval === arg. dval[2 ]
104+ A_is_arg = A_is_arg1 || A_is_arg2
128105 argval = something (cache_arg, arg. val)
129106 if ! isa (A, Const)
130- $ pb (A. dval, Aval, argval, darg)
107+ ΔA = A_is_arg ? make_zero (A. dval) : A. dval
108+ $ pb (ΔA, Aval, argval, darg)
109+ A_is_arg && (A. dval .= ΔA)
110+ end
111+ if ! isa (arg, Const)
112+ A_is_arg1 || make_zero! (arg. dval[1 ])
113+ A_is_arg2 || make_zero! (arg. dval[2 ])
131114 end
132- ! isa (arg, Const) && make_zero! (arg. dval)
133115 return (nothing , nothing , nothing )
134116 end
135117 end
@@ -343,7 +325,13 @@ for (f, trunc_f, full_f, pb) in (
343325 if ! isa (A, Const)
344326 $ pb (A. dval, Aval, DVval, dDVtrunc, ind)
345327 end
346- ! isa (DV, Const) && make_zero! (DV. dval)
328+ if ! isa (DV, Const)
329+ if A. dval != = DV. dval[1 ]
330+ make_zero! (DV. dval)
331+ else
332+ make_zero! (DV. dval[2 ])
333+ end
334+ end
347335 return (nothing , nothing , nothing )
348336 end
349337end
@@ -379,21 +367,24 @@ for (f!, f_full!, pb!) in (
379367 func:: Const{typeof($f!)} ,
380368 :: Type{RT} ,
381369 cache,
382- A:: Annotation ,
370+ A:: Annotation{TA} ,
383371 D:: Annotation ,
384372 alg:: Const{<:MatrixAlgebraKit.AbstractAlgorithm} ,
385- ) where {RT}
373+ ) where {RT, TA }
386374 cache_D, dD, V = cache
387375 Dval = something (cache_D, D. val)
388376 # A is NOT used in the pullback, so we assign Aval = nothing
389377 # to trigger an error in case the pullback is modified to directly
390378 # use A (so that whoever does this is forced to handle caching A
391379 # appropriately here)
392380 Aval = nothing
381+ A_is_arg = ! isa (A, Const) && TA <: Diagonal && diagview (A. dval) === D. dval
393382 if ! isa (A, Const)
394- $ pb! (A. dval, Aval, (Diagonal (Dval), V), dD)
383+ ΔA = A_is_arg ? make_zero (A. dval) : A. dval
384+ $ pb! (ΔA, Aval, (Diagonal (Dval), V), dD)
385+ A_is_arg && (A. dval .= ΔA)
395386 end
396- ! isa (D, Const) && make_zero! (D. dval)
387+ ! isa (D, Const) && ! A_is_arg && make_zero! (D. dval)
397388 return (nothing , nothing , nothing )
398389 end
399390 end
@@ -425,21 +416,24 @@ function EnzymeRules.reverse(
425416 func:: Const{typeof(svd_vals!)} ,
426417 :: Type{RT} ,
427418 cache,
428- A:: Annotation ,
419+ A:: Annotation{TA} ,
429420 S:: Annotation ,
430421 alg:: Const{<:MatrixAlgebraKit.AbstractAlgorithm} ,
431- ) where {RT}
422+ ) where {RT, TA }
432423 cache_S, dS, U, Vᴴ = cache
433424 # A is NOT used in the pullback, so we assign Aval = nothing
434425 # to trigger an error in case the pullback is modified to directly
435426 # use A (so that whoever does this is forced to handle caching A
436427 # appropriately here)
437428 Aval = nothing
438429 Sval = something (cache_S, S. val)
430+ A_is_arg = ! isa (A, Const) && TA <: Diagonal && diagview (A. dval) === S. dval
439431 if ! isa (A, Const)
440- svd_vals_pullback! (A. dval, Aval, (U, Diagonal (Sval), Vᴴ), dS)
432+ ΔA = A_is_arg ? make_zero (A. dval) : A. dval
433+ svd_vals_pullback! (ΔA, Aval, (U, Diagonal (Sval), Vᴴ), dS)
434+ A_is_arg && (A. dval .= ΔA)
441435 end
442- ! isa (S, Const) && make_zero! (S. dval)
436+ ! isa (S, Const) && ! A_is_arg && make_zero! (S. dval)
443437 return (nothing , nothing , nothing )
444438end
445439
0 commit comments