Skip to content

Commit dc6a765

Browse files
committed
Small fixes
1 parent a22a1c3 commit dc6a765

3 files changed

Lines changed: 13 additions & 13 deletions

File tree

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@ if !is_buildkite
2828
@safetestset "Image and Null Space" begin
2929
include("orthnull.jl")
3030
end
31-
@safetestset "ChainRules" begin
32-
include("chainrules.jl")
33-
end
3431
@safetestset "MatrixAlgebraKit.jl" begin
3532
@safetestset "Code quality (Aqua.jl)" begin
3633
using MatrixAlgebraKit
@@ -71,6 +68,9 @@ end
7168
@safetestset "Mooncake" begin
7269
include("mooncake.jl")
7370
end
71+
@safetestset "ChainRules" begin
72+
include("chainrules.jl")
73+
end
7474

7575
using CUDA
7676
if CUDA.functional()

test/testsuite/ad_utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function stabilize_eigvals!(D::AbstractVector)
4141
# rescale eigenvalues so that they lie on distinct radii in the complex plane
4242
# that are chosen randomly in non-overlapping intervals [k/n, (k+0.5)/n)] for k=1,...,n
4343
radii = ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n
44-
hD .= sign.(collect(D)) .* radii[p]
44+
hD = sign.(collect(D)) .* radii[p]
4545
copyto!(D, hD)
4646
return D
4747
end
@@ -138,7 +138,7 @@ function ad_lq_full_setup(A)
138138
Q1 = view(Q, 1:minmn, 1:n)
139139
ΔQ = randn!(similar(A, T, n, n))
140140
ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n)
141-
ΔQ2 = ΔQ2 * Q1' * Q1
141+
ΔQ2 = (ΔQ2 * Q1') * Q1
142142
ΔL = randn!(similar(A, T, m, n))
143143
return (L, Q), (ΔL, ΔQ)
144144
end

test/testsuite/mooncake.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ function test_mooncake_eig(
275275
test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
276276
dDVtrunc = make_mooncake_tangent(ΔDVtrunc)
277277
Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol)
278-
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg)
278+
test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg)
279279
end
280280
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real))
281281
DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg)
@@ -285,7 +285,7 @@ function test_mooncake_eig(
285285
test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
286286
dDVtrunc = make_mooncake_tangent(ΔDVtrunc)
287287
Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol)
288-
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg)
288+
test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg)
289289
end
290290
end
291291
end
@@ -322,8 +322,8 @@ function test_mooncake_eigh(
322322
Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false)
323323
test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
324324
dDVtrunc = make_mooncake_tangent(ΔDVtrunc)
325-
Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
326-
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg)
325+
Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
326+
test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg)
327327
end
328328
D = eigh_vals(A / 2)
329329
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), trunctol(; atol = maximum(abs, D) / 2))
@@ -333,8 +333,8 @@ function test_mooncake_eigh(
333333
Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false)
334334
test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
335335
dDVtrunc = make_mooncake_tangent(ΔDVtrunc)
336-
Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
337-
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg)
336+
Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
337+
test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg)
338338
end
339339
end
340340
end
@@ -378,7 +378,7 @@ function test_mooncake_svd(
378378
test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
379379
dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc)
380380
Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol)
381-
test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg)
381+
test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg)
382382
end
383383
@testset "trunctol" begin
384384
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2))
@@ -389,7 +389,7 @@ function test_mooncake_svd(
389389
test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
390390
dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc)
391391
Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol)
392-
test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg)
392+
test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg)
393393
end
394394
end
395395
end

0 commit comments

Comments
 (0)