Skip to content

Commit 187a4e6

Browse files
committed
A bit of Enzyme cleanup
1 parent cd851d1 commit 187a4e6

1 file changed

Lines changed: 14 additions & 30 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,13 @@ for (f, pb) in (
104104
A_is_arg = A_is_arg1 || A_is_arg2
105105
argval = something(cache_arg, arg.val)
106106
if !isa(A, Const)
107-
if A_is_arg
108-
ΔA = make_zero(A.dval)
109-
$pb(ΔA, Aval, argval, darg)
110-
A.dval .= ΔA
111-
else
112-
$pb(A.dval, Aval, argval, darg)
113-
end
107+
ΔA = A_is_arg ? make_zero(A.dval) : A.dval
108+
$pb(ΔA, Aval, argval, darg)
109+
A_is_arg && (A.dval .= ΔA)
114110
end
115111
if !isa(arg, Const)
116-
A.dval === arg.dval[1] || make_zero!(arg.dval[1])
117-
A.dval === arg.dval[2] || make_zero!(arg.dval[2])
112+
A_is_arg1 || make_zero!(arg.dval[1])
113+
A_is_arg2 || make_zero!(arg.dval[2])
118114
end
119115
return (nothing, nothing, nothing)
120116
end
@@ -343,7 +339,7 @@ for (f, trunc_f, full_f, pb) in (
343339
$pb(A.dval, Aval, DVval, dDVtrunc, ind)
344340
end
345341
if !isa(DV, Const)
346-
if !(A.dval === DV.dval[1])
342+
if A.dval !== DV.dval[1]
347343
make_zero!(DV.dval)
348344
else
349345
make_zero!(DV.dval[2])
@@ -397,17 +393,11 @@ for (f!, f_full!, pb!) in (
397393
Aval = nothing
398394
A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === D.dval
399395
if !isa(A, Const)
400-
if A_is_arg
401-
ΔA = make_zero(A.dval)
402-
$pb!(ΔA, Aval, (Diagonal(Dval), V), dD)
403-
A.dval .= ΔA
404-
else
405-
$pb!(A.dval, Aval, (Diagonal(Dval), V), dD)
406-
end
407-
end
408-
if !isa(D, Const) && !A_is_arg
409-
make_zero!(D.dval)
396+
ΔA = A_is_arg ? make_zero(A.dval) : A.dval
397+
$pb!(ΔA, Aval, (Diagonal(Dval), V), dD)
398+
A_is_arg && (A.dval .= ΔA)
410399
end
400+
!isa(D, Const) && !A_is_arg && make_zero!(D.dval)
411401
return (nothing, nothing, nothing)
412402
end
413403
end
@@ -452,17 +442,11 @@ function EnzymeRules.reverse(
452442
Sval = something(cache_S, S.val)
453443
A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === S.dval
454444
if !isa(A, Const)
455-
if A_is_arg
456-
ΔA = make_zero(A.dval)
457-
svd_vals_pullback!(ΔA, Aval, (U, Diagonal(Sval), Vᴴ), dS)
458-
A.dval .= ΔA
459-
else
460-
svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS)
461-
end
462-
end
463-
if !isa(S, Const) && !A_is_arg
464-
make_zero!(S.dval)
445+
ΔA = A_is_arg ? make_zero(A.dval) : A.dval
446+
svd_vals_pullback!(ΔA, Aval, (U, Diagonal(Sval), Vᴴ), dS)
447+
A_is_arg && (A.dval .= ΔA)
465448
end
449+
!isa(S, Const) && !A_is_arg && make_zero!(S.dval)
466450
return (nothing, nothing, nothing)
467451
end
468452

0 commit comments

Comments
 (0)