Skip to content

Commit fd5a0e4

Browse files
authored
Test polar AD for CUDA (#237)
* Test polar AD for CUDA * Move atol and rtol
1 parent 1d5b1d8 commit fd5a0e4

1 file changed

Lines changed: 13 additions & 4 deletions

File tree

test/mooncake/polar.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,23 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1313
m = 19
1414
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
1515
TestSuite.seed_rng!(1234)
16+
atol = rtol = m * n * TestSuite.precision(T)
1617
if !is_buildkite
17-
atol = rtol = m * n * TestSuite.precision(T)
1818
m >= n && TestSuite.test_mooncake_left_polar(T, (m, n); atol, rtol)
1919
n >= m && TestSuite.test_mooncake_right_polar(T, (m, n); atol, rtol)
2020
#=if m == n
2121
AT = Diagonal{T, Vector{T}}
22-
TestSuite.test_mooncake_left_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
23-
TestSuite.test_mooncake_right_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
24-
end=# # broken due to pullback
22+
TestSuite.test_mooncake_left_polar(AT, m; atol, rtol)
23+
TestSuite.test_mooncake_right_polar(AT, m; atol, rtol)
24+
end=#
25+
end
26+
if T BLASFloats && CUDA.functional()
27+
m >= n && TestSuite.test_mooncake_left_polar(CuMatrix{T}, (m, n); atol, rtol)
28+
n >= m && TestSuite.test_mooncake_right_polar(CuMatrix{T}, (m, n); atol, rtol)
29+
#=if m == n
30+
AT = Diagonal{T, CuVector{T}}
31+
TestSuite.test_mooncake_left_polar(AT, m; atol, rtol)
32+
TestSuite.test_mooncake_right_polar(AT, m; atol, rtol)
33+
end=#
2534
end
2635
end

0 commit comments

Comments
 (0)