Skip to content

Commit 517706a

Browse files
committed
LQ/QR working
1 parent 256642b commit 517706a

3 files changed

Lines changed: 14 additions & 3 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ for (f, pb) in (
7070
# if arg.val == ret, the annotation must be Duplicated or DuplicatedNoNeed
7171
# if arg isa Const, ret may still be modified further down the call graph so we should
7272
# copy it to protect ourselves
73-
cache_arg = (arg.val !== ret && arg.val[1] !== A.val) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing
73+
A_is_arg1 = !isa(A, Const) && A.val === arg.val[1]
74+
A_is_arg2 = !isa(A, Const) && A.val === arg.val[2]
75+
A_is_arg = A_is_arg1 || A_is_arg2
76+
cache_arg = (arg.val !== ret && !A_is_arg) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing
7477
dret = if EnzymeRules.needs_shadow(config) && ((TA == Nothing && TB == Nothing) || isa(arg, Const))
7578
make_zero.(ret)
7679
elseif EnzymeRules.needs_shadow(config)
@@ -96,7 +99,9 @@ for (f, pb) in (
9699
# use A (so that whoever does this is forced to handle caching A
97100
# appropriately here)
98101
Aval = nothing
99-
A_is_arg = !isa(A, Const) && A.dval === arg.dval[1]
102+
A_is_arg1 = !isa(A, Const) && A.dval === arg.dval[1]
103+
A_is_arg2 = !isa(A, Const) && A.dval === arg.dval[2]
104+
A_is_arg = A_is_arg1 || A_is_arg2
100105
argval = something(cache_arg, arg.val)
101106
if !isa(A, Const)
102107
if A_is_arg
@@ -110,8 +115,10 @@ for (f, pb) in (
110115
if !isa(arg, Const)
111116
if !A_is_arg
112117
make_zero!(arg.dval)
113-
else
118+
elseif A_is_arg1
114119
make_zero!(arg.dval[2])
120+
elseif A_is_arg2
121+
make_zero!(arg.dval[1])
115122
end
116123
end
117124
return (nothing, nothing, nothing)

test/enzyme/lq.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
1515
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_enzyme_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
AT = Diagonal{T, Vector{T}}
19+
m == n && TestSuite.test_enzyme_lq(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
1820
end
1921
end

test/enzyme/qr.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
1515
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_enzyme_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
AT = Diagonal{T, Vector{T}}
19+
m == n && TestSuite.test_enzyme_qr(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
1820
end
1921
end

0 commit comments

Comments
 (0)