Skip to content

Commit 742572f

Browse files
authored
Test orthnull with polar in fwd mode (#246)
1 parent 183370d commit 742572f

2 files changed

Lines changed: 14 additions & 8 deletions

File tree

test/testsuite/enzyme/orthnull.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717
"""
1818
test_enzyme_left_orth(T, sz; rng, atol, rtol)
1919
20-
Test the Enzyme reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`)
20+
Test the Enzyme forward- and reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`)
2121
algorithms, and their in-place variants.
2222
"""
2323
function test_enzyme_left_orth(
@@ -44,6 +44,9 @@ function test_enzyme_left_orth(
4444
VC, ΔVC = ad_left_orth_setup(A)
4545
test_reverse(left_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔVC)
4646
test_reverse(call_and_zero!, RT, (left_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔVC)
47+
A = instantiate_matrix(T, sz)
48+
test_forward(left_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm)
49+
test_forward(call_and_zero!, RT, (left_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm)
4750
end
4851
end
4952
end
@@ -52,7 +55,7 @@ end
5255
"""
5356
test_enzyme_right_orth(T, sz; rng, atol, rtol)
5457
55-
Test the Enzyme reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`)
58+
Test the Enzyme forward- and reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`)
5659
algorithms, and their in-place variants.
5760
"""
5861
function test_enzyme_right_orth(
@@ -78,6 +81,9 @@ function test_enzyme_right_orth(
7881
CVᴴ, ΔCVᴴ = ad_right_orth_setup(A)
7982
test_reverse(right_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔCVᴴ)
8083
test_reverse(call_and_zero!, RT, (right_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔCVᴴ)
84+
A = instantiate_matrix(T, sz)
85+
test_forward(right_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm)
86+
test_forward(call_and_zero!, RT, (right_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm)
8187
end
8288
end
8389
end

test/testsuite/mooncake/orthnull.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717
"""
1818
test_mooncake_left_orth(T, sz; rng, atol, rtol)
1919
20-
Test the Mooncake reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`)
20+
Test the Mooncake forward- and reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`)
2121
algorithms, and their in-place variants.
2222
"""
2323
function test_mooncake_left_orth(
@@ -51,11 +51,11 @@ function test_mooncake_left_orth(
5151

5252
Mooncake.TestUtils.test_rule(
5353
rng, left_orth, A, alg;
54-
mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol
54+
output_tangent, is_primitive = false, atol, rtol
5555
)
5656
Mooncake.TestUtils.test_rule(
5757
rng, call_and_zero!, left_orth!, A, alg;
58-
mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol
58+
output_tangent, is_primitive = false, atol, rtol
5959
)
6060
end
6161
end
@@ -65,7 +65,7 @@ end
6565
"""
6666
test_mooncake_right_orth(T, sz; rng, atol, rtol)
6767
68-
Test the Mooncake reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`)
68+
Test the Mooncake forward- and reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`)
6969
algorithms, and their in-place variants.
7070
"""
7171
function test_mooncake_right_orth(
@@ -99,11 +99,11 @@ function test_mooncake_right_orth(
9999

100100
Mooncake.TestUtils.test_rule(
101101
rng, right_orth, A, alg;
102-
mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol
102+
output_tangent, is_primitive = false, atol, rtol
103103
)
104104
Mooncake.TestUtils.test_rule(
105105
rng, call_and_zero!, right_orth!, A, alg;
106-
mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol
106+
output_tangent, is_primitive = false, atol, rtol
107107
)
108108
end
109109
end

0 commit comments

Comments
 (0)