Skip to content

Commit e0534f5

Browse files
committed
Diag tests working for Mooncake
1 parent afcddf2 commit e0534f5

6 files changed

Lines changed: 23 additions & 11 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,24 @@ for (f!, f, pb, adj) in (
4040
arg2c = copy(arg2)
4141
$f!(A, args, Mooncake.primal(alg_dalg))
4242
function $adj(::NoRData)
43-
copy!(A, Ac)
43+
# DON'T copy Ac to A if A === one
44+
# of the output args -- this can
45+
# mess up the pullback because
46+
# generally the args are used there
4447
if !(A === arg1 || A === arg2)
48+
copy!(A, Ac)
4549
$pb(dA, A, (arg1, arg2), (darg1, darg2))
4650
else
4751
ΔA = zero(A)
4852
$pb(ΔA, A, (arg1, arg2), (darg1, darg2))
4953
dA .= ΔA
5054
end
5155
if A === arg1
56+
copy!(A, Ac)
5257
zero!(darg2)
5358
copy!(arg2, arg2c)
5459
elseif A === arg2
60+
copy!(A, Ac)
5561
zero!(darg1)
5662
copy!(arg1, arg1c)
5763
else
@@ -60,7 +66,7 @@ for (f!, f, pb, adj) in (
6066
copy!(arg2, arg2c)
6167
copy!(arg1, arg1c)
6268
end
63-
return NoRData(), NoRData(), NoRData(), NoRData()
69+
return ntuple(Returns(NoRData()), 4)
6470
end
6571
return args_dargs, $adj
6672
end

test/mooncake/eig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...)
1616
if !is_buildkite
1717
TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
1818
AT = Diagonal{T, Vector{T}}
19-
TestSuite.test_mooncake_eig(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
19+
TestSuite.test_mooncake_eig(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
2020
end
2121
end

test/mooncake/lq.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
1515
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_mooncake_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18-
#=if m == n
18+
if m == n
1919
AT = Diagonal{T, Vector{T}}
20-
TestSuite.test_mooncake_lq(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
21-
end=# # broken with singular exception
20+
TestSuite.test_mooncake_lq(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
21+
end
2222
end
2323
end

test/mooncake/orthnull.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
1515
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_mooncake_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18-
#=if m == n
18+
if m == n
1919
AT = Diagonal{T, Vector{T}}
2020
TestSuite.test_mooncake_orthnull(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
21-
end=#
21+
end
2222
end
2323
end

test/mooncake/qr.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
1515
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_mooncake_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18-
#=if m == n
18+
if m == n
1919
AT = Diagonal{T, Vector{T}}
20-
TestSuite.test_mooncake_qr(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
21-
end=#
20+
TestSuite.test_mooncake_qr(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
21+
end
2222
end
2323
end

test/testsuite/TestSuite.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@ function instantiate_rank_deficient_matrix(T, sz; trunc = truncrank(div(min(sz..
9696
return mul!(A, V, C)
9797
end
9898

99+
function instantiate_rank_deficient_matrix(::Type{T}, sz; trunc = truncrank(div(min(sz...), 2))) where {T <: Diagonal}
100+
A = instantiate_matrix(eltype(T), sz)
101+
V, C = left_orth!(A; trunc)
102+
return Diagonal(diag(mul!(A, V, C)))
103+
end
104+
99105
include("ad_utils.jl")
100106

101107
include("projections.jl")

0 commit comments

Comments
 (0)