@@ -104,17 +104,13 @@ for (f, pb) in (
104104 A_is_arg = A_is_arg1 || A_is_arg2
105105 argval = something (cache_arg, arg. val)
106106 if ! isa (A, Const)
107- if A_is_arg
108- ΔA = make_zero (A. dval)
109- $ pb (ΔA, Aval, argval, darg)
110- A. dval .= ΔA
111- else
112- $ pb (A. dval, Aval, argval, darg)
113- end
107+ ΔA = A_is_arg ? make_zero (A. dval) : A. dval
108+ $ pb (ΔA, Aval, argval, darg)
109+ A_is_arg && (A. dval .= ΔA)
114110 end
115111 if ! isa (arg, Const)
116- A . dval === arg . dval[ 1 ] || make_zero! (arg. dval[1 ])
117- A . dval === arg . dval[ 2 ] || make_zero! (arg. dval[2 ])
112+ A_is_arg1 || make_zero! (arg. dval[1 ])
113+ A_is_arg2 || make_zero! (arg. dval[2 ])
118114 end
119115 return (nothing , nothing , nothing )
120116 end
@@ -343,7 +339,7 @@ for (f, trunc_f, full_f, pb) in (
343339 $ pb (A. dval, Aval, DVval, dDVtrunc, ind)
344340 end
345341 if ! isa (DV, Const)
346- if ! ( A. dval === DV. dval[1 ])
342+ if A. dval != = DV. dval[1 ]
347343 make_zero! (DV. dval)
348344 else
349345 make_zero! (DV. dval[2 ])
@@ -397,17 +393,11 @@ for (f!, f_full!, pb!) in (
397393 Aval = nothing
398394 A_is_arg = ! isa (A, Const) && TA <: Diagonal && diagview (A. dval) === D. dval
399395 if ! isa (A, Const)
400- if A_is_arg
401- ΔA = make_zero (A. dval)
402- $ pb! (ΔA, Aval, (Diagonal (Dval), V), dD)
403- A. dval .= ΔA
404- else
405- $ pb! (A. dval, Aval, (Diagonal (Dval), V), dD)
406- end
407- end
408- if ! isa (D, Const) && ! A_is_arg
409- make_zero! (D. dval)
396+ ΔA = A_is_arg ? make_zero (A. dval) : A. dval
397+ $ pb! (ΔA, Aval, (Diagonal (Dval), V), dD)
398+ A_is_arg && (A. dval .= ΔA)
410399 end
400+ ! isa (D, Const) && ! A_is_arg && make_zero! (D. dval)
411401 return (nothing , nothing , nothing )
412402 end
413403 end
@@ -452,17 +442,11 @@ function EnzymeRules.reverse(
452442 Sval = something (cache_S, S. val)
453443 A_is_arg = ! isa (A, Const) && TA <: Diagonal && diagview (A. dval) === S. dval
454444 if ! isa (A, Const)
455- if A_is_arg
456- ΔA = make_zero (A. dval)
457- svd_vals_pullback! (ΔA, Aval, (U, Diagonal (Sval), Vᴴ), dS)
458- A. dval .= ΔA
459- else
460- svd_vals_pullback! (A. dval, Aval, (U, Diagonal (Sval), Vᴴ), dS)
461- end
462- end
463- if ! isa (S, Const) && ! A_is_arg
464- make_zero! (S. dval)
445+ ΔA = A_is_arg ? make_zero (A. dval) : A. dval
446+ svd_vals_pullback! (ΔA, Aval, (U, Diagonal (Sval), Vᴴ), dS)
447+ A_is_arg && (A. dval .= ΔA)
465448 end
449+ ! isa (S, Const) && ! A_is_arg && make_zero! (S. dval)
466450 return (nothing , nothing , nothing )
467451end
468452
0 commit comments