Skip to content

Commit 92f126d

Browse files
Make _matfun and tr GPU tests version-aware
- _matfun: only return Hermitian on Julia 1.12+ where Julia itself does - tr GPU tests: @gpu_broken only on Julia 1.12+ (scalar indexing issue), @gpu on older versions where it works Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 26d16e6 commit 92f126d

2 files changed

Lines changed: 10 additions & 4 deletions

File tree

src/rulesets/LinearAlgebra/symmetric.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,8 @@ function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
386386
elseif eltype(fλ) <: Complex
387387
# Complex input with complex output: plain Matrix
388388
fA
389-
elseif A isa Hermitian
389+
elseif A isa Hermitian && VERSION >= v"1.12.0-DEV.0"
390+
# Julia 1.12+ returns Hermitian for Hermitian input with real output
390391
Hermitian(fA)
391392
else
392393
Symmetric(fA)

test/rulesets/LinearAlgebra/dense.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,14 @@
138138
test_rrule(logabsdet, -B)
139139
end
140140
@testset "tr" begin
141-
# tr uses scalar indexing in LinearAlgebra, broken on GPU arrays
142-
@gpu_broken test_frule(tr, randn(4, 4))
143-
@gpu_broken test_rrule(tr, randn(4, 4))
141+
if VERSION >= v"1.12.0-DEV.0"
142+
# tr uses scalar indexing in LinearAlgebra on Julia 1.12+, broken on GPU arrays
143+
@gpu_broken test_frule(tr, randn(4, 4))
144+
@gpu_broken test_rrule(tr, randn(4, 4))
145+
else
146+
@gpu test_frule(tr, randn(4, 4))
147+
@gpu test_rrule(tr, randn(4, 4))
148+
end
144149
end
145150
@testset "sylvester" begin
146151
@testset "T=$T, m=$m, n=$n" for T in (Float64, ComplexF64), m in (2, 3), n in (1, 3)

0 commit comments

Comments
 (0)