Skip to content

Commit c7efc19

Browse files
committed
Let's try with BigFloat
1 parent daeab4d commit c7efc19

2 files changed

Lines changed: 5 additions & 3 deletions

File tree

test/mooncake.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ using CUDA, AMDGPU
55

66
#BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
77
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
8-
8+
GenericFloats = (BigFloat,)
99
@isdefined(TestSuite) || include("testsuite/TestSuite.jl")
1010
using .TestSuite
1111

1212
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1313

1414
m = 19
15-
for T in BLASFloats, n in (17, m, 23)
15+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
1616
TestSuite.seed_rng!(123)
1717
if CUDA.functional()
1818
TestSuite.test_mooncake(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))

test/testsuite/mooncake.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ using MatrixAlgebraKit
33
using Mooncake, Mooncake.TestUtils
44
using Mooncake: rrule!!
55
using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc
6+
using LinearAlgebra: BlasFloat
7+
using GenericLinearAlgebra
68

79
function mc_copy_eigh_full(A; kwargs...)
810
A = (A + A') / 2
@@ -160,7 +162,7 @@ function test_mooncake(T::Type, sz; kwargs...)
160162
test_mooncake_qr(T, sz; kwargs...)
161163
test_mooncake_lq(T, sz; kwargs...)
162164
if length(sz) == 1 || sz[1] == sz[2]
163-
test_mooncake_eig(T, sz; kwargs...)
165+
T <: BlasFloat && test_mooncake_eig(T, sz; kwargs...)
164166
test_mooncake_eigh(T, sz; kwargs...)
165167
end
166168
test_mooncake_svd(T, sz; kwargs...)

0 commit comments

Comments
 (0)