Skip to content

Commit 750529c

Browse files
committed
Diagonal + eig(h) working
1 parent 415f3f4 commit 750529c

3 files changed

Lines changed: 50 additions & 45 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -43,37 +43,6 @@ using LinearAlgebra
4343
# other provided input variable.
4444
#--------------------------------------------
4545

46-
# this rule is necessary for now as without it,
47-
# a segfault occurs both on 1.10 and 1.12 -- likely
48-
# a deeper internal bug
49-
function EnzymeRules.augmented_primal(
50-
config::EnzymeRules.RevConfigWidth{1},
51-
func::Const{typeof(copy_input)},
52-
::Type{RT},
53-
f::Annotation,
54-
A::Annotation
55-
) where {RT}
56-
ret = func.val(f.val, A.val)
57-
primal = EnzymeRules.needs_primal(config) ? ret : nothing
58-
shadow = EnzymeRules.needs_shadow(config) ? zero(A.dval) : nothing
59-
return EnzymeRules.AugmentedReturn(primal, shadow, shadow)
60-
end
61-
62-
function EnzymeRules.reverse(
63-
config::EnzymeRules.RevConfigWidth{1},
64-
func::Const{typeof(copy_input)},
65-
::Type{RT},
66-
cache,
67-
f::Annotation,
68-
A::Annotation
69-
) where {RT}
70-
copy_shadow = cache
71-
if !isa(A, Const) && !isnothing(copy_shadow)
72-
A.dval .+= copy_shadow
73-
end
74-
return (nothing, nothing)
75-
end
76-
7746
# two-argument factorizations like LQ, QR, EIG
7847
for (f, pb) in (
7948
(qr_full!, qr_pullback!),
@@ -101,9 +70,11 @@ for (f, pb) in (
10170
# if arg.val == ret, the annotation must be Duplicated or DuplicatedNoNeed
10271
# if arg isa Const, ret may still be modified further down the call graph so we should
10372
# copy it to protect ourselves
104-
cache_arg = (arg.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing
105-
dret = if EnzymeRules.needs_shadow(config)
106-
(TA == Nothing && TB == Nothing) || isa(arg, Const) ? zero.(ret) : arg.dval
73+
cache_arg = (arg.val !== ret && arg.val[1] !== A.val) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing
74+
dret = if EnzymeRules.needs_shadow(config) && ((TA == Nothing && TB == Nothing) || isa(arg, Const))
75+
make_zero.(ret)
76+
elseif EnzymeRules.needs_shadow(config)
77+
arg.dval
10778
else
10879
nothing
10980
end
@@ -125,11 +96,24 @@ for (f, pb) in (
12596
# use A (so that whoever does this is forced to handle caching A
12697
# appropriately here)
12798
Aval = nothing
99+
A_is_arg = !isa(A, Const) && A.dval === arg.dval[1]
128100
argval = something(cache_arg, arg.val)
129101
if !isa(A, Const)
130-
$pb(A.dval, Aval, argval, darg)
102+
if A_is_arg
103+
ΔA = make_zero(A.dval)
104+
$pb(ΔA, Aval, argval, darg)
105+
A.dval .= ΔA
106+
else
107+
$pb(A.dval, Aval, argval, darg)
108+
end
109+
end
110+
if !isa(arg, Const)
111+
if !A_is_arg
112+
make_zero!(arg.dval)
113+
else
114+
make_zero!(arg.dval[2])
115+
end
131116
end
132-
!isa(arg, Const) && make_zero!(arg.dval)
133117
return (nothing, nothing, nothing)
134118
end
135119
end
@@ -356,7 +340,13 @@ for (f, trunc_f, full_f, pb) in (
356340
if !isa(A, Const)
357341
$pb(A.dval, Aval, DVval, dDVtrunc, ind)
358342
end
359-
!isa(DV, Const) && make_zero!(DV.dval)
343+
if !isa(DV, Const)
344+
if !(A.dval === DV.dval[1])
345+
make_zero!(DV.dval)
346+
else
347+
make_zero!(DV.dval[2])
348+
end
349+
end
360350
return (nothing, nothing, nothing)
361351
end
362352
end
@@ -392,21 +382,30 @@ for (f!, f_full!, pb!) in (
392382
func::Const{typeof($f!)},
393383
::Type{RT},
394384
cache,
395-
A::Annotation,
385+
A::Annotation{TA},
396386
D::Annotation,
397387
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
398-
) where {RT}
388+
) where {RT, TA}
399389
cache_D, dD, V = cache
400390
Dval = something(cache_D, D.val)
401391
# A is NOT used in the pullback, so we assign Aval = nothing
402392
# to trigger an error in case the pullback is modified to directly
403393
# use A (so that whoever does this is forced to handle caching A
404394
# appropriately here)
405395
Aval = nothing
396+
A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === D.dval
406397
if !isa(A, Const)
407-
$pb!(A.dval, Aval, (Diagonal(Dval), V), dD)
398+
if A_is_arg
399+
ΔA = make_zero(A.dval)
400+
$pb!(ΔA, Aval, (Diagonal(Dval), V), dD)
401+
A.dval .= ΔA
402+
else
403+
$pb!(A.dval, Aval, (Diagonal(Dval), V), dD)
404+
end
405+
end
406+
if !isa(D, Const) && !A_is_arg
407+
make_zero!(D.dval)
408408
end
409-
!isa(D, Const) && make_zero!(D.dval)
410409
return (nothing, nothing, nothing)
411410
end
412411
end
@@ -438,10 +437,10 @@ function EnzymeRules.reverse(
438437
func::Const{typeof(svd_vals!)},
439438
::Type{RT},
440439
cache,
441-
A::Annotation,
440+
A::Annotation{TA},
442441
S::Annotation,
443442
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
444-
) where {RT}
443+
) where {RT, TA}
445444
cache_S, dS, U, Vᴴ = cache
446445
# A is NOT used in the pullback, so we assign Aval = nothing
447446
# to trigger an error in case the pullback is modified to directly
@@ -452,7 +451,9 @@ function EnzymeRules.reverse(
452451
if !isa(A, Const)
453452
svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS)
454453
end
455-
!isa(S, Const) && make_zero!(S.dval)
454+
if !isa(S, Const) && !(TA <: Diagonal && (diagview(A.dval) === S.dval))
455+
make_zero!(S.dval)
456+
end
456457
return (nothing, nothing, nothing)
457458
end
458459

test/enzyme/eig.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...)
1515
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
18+
AT = Diagonal{T, Vector{T}}
19+
TestSuite.test_enzyme_eig(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
1820
end
1921
end

test/enzyme/eigh.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ m = 19
1414
for T in (BLASFloats..., GenericFloats...)
1515
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
17-
TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
17+
#TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
18+
AT = Diagonal{T, Vector{T}}
19+
TestSuite.test_enzyme_eigh(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
1820
end
1921
end

0 commit comments

Comments
 (0)