Skip to content

Commit 8152728

Browse files
mtfishmanclaude
andcommitted
Replace @test_broken with @test ... broken = ...
Fold each `if arrayt === Array; @test ...; else @test_broken ...; end` pair into a single `@test expr broken = (arrayt ≢ Array && ...)` form so the array-type branching lives in the `broken` argument rather than duplicated control flow. Drop the now-unused `@test_broken` import. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent c40206f commit 8152728

1 file changed

Lines changed: 22 additions & 43 deletions

File tree

test/test_abstract_blocktype.jl

Lines changed: 22 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using MatrixAlgebraKit: eig_full, eig_trunc, eig_vals, eigh_full, eigh_trunc, ei
88
right_orth, right_polar, svd_compact, svd_full, svd_trunc
99
using SparseArraysBase: storedlength
1010
using StableRNGs: StableRNG
11-
using Test: @test, @test_broken, @testset
11+
using Test: @test, @testset
1212

1313
elts = (Float32, Float64, ComplexF32)
1414
arrayts = (Array, JLArray)
@@ -60,50 +60,38 @@ arrayts = (Array, JLArray)
6060
a = BlockSparseMatrix{elt, AbstractMatrix{elt}}(undef, [2, 3], [2, 3])
6161
a[Block(1, 1)] = dev(randn(rng, elt, 2, 2))
6262
for f in (eig_full, eig_trunc)
63-
if arrayt === Array
63+
@test begin
6464
d, v = f(a)
65-
@test a * v v * d
66-
else
67-
@test_broken f(a)
68-
end
65+
a * v v * d
66+
end broken = arrayt Array
6967
end
70-
if arrayt === Array
68+
@test begin
7169
d = eig_vals(a)
72-
@test sort(Vector(d); by = abs) sort(eig_vals(Matrix(a)); by = abs)
73-
else
74-
@test_broken eig_vals(a)
75-
end
70+
sort(Vector(d); by = abs) sort(eig_vals(Matrix(a)); by = abs)
71+
end broken = arrayt Array
7672

7773
rng = StableRNG(1234)
7874
a = BlockSparseMatrix{elt, AbstractMatrix{elt}}(undef, [2, 3], [2, 3])
7975
a[Block(1, 1)] = dev(parent(hermitianpart(randn(rng, elt, 2, 2))))
8076
for f in (eigh_full, eigh_trunc)
81-
if arrayt === Array
77+
@test begin
8278
d, v = f(a)
83-
@test a * v v * d
84-
else
85-
@test_broken f(a)
86-
end
79+
a * v v * d
80+
end broken = arrayt Array
8781
end
88-
if arrayt === Array
82+
@test begin
8983
d = eigh_vals(a)
90-
@test sort(Vector(d); by = abs) sort(eig_vals(Matrix(a)); by = abs)
91-
else
92-
@test_broken eigh_vals(a)
93-
end
84+
sort(Vector(d); by = abs) sort(eig_vals(Matrix(a)); by = abs)
85+
end broken = arrayt Array
9486

9587
rng = StableRNG(1234)
9688
a = BlockSparseMatrix{elt, AbstractMatrix{elt}}(undef, [2, 3], [2, 3])
9789
a[Block(1, 1)] = dev(randn(rng, elt, 2, 2))
9890
for f in (left_orth, left_polar, qr_compact, qr_full)
9991
u, c = f(a)
10092
@test u * c a
101-
if arrayt Array
102-
@test isisometric(u; side = :left)
103-
else
104-
# TODO: Fix comparison with UniformScaling on GPU.
105-
@test_broken isisometric(u; side = :left)
106-
end
93+
# TODO: Fix comparison with UniformScaling on GPU.
94+
@test isisometric(u; side = :left) broken = arrayt Array
10795
end
10896
for f in (right_orth, right_polar, lq_compact, lq_full)
10997
c, u = f(a)
@@ -112,24 +100,15 @@ arrayts = (Array, JLArray)
112100
# reproduce `a` (regression in MatrixAlgebraKit 0.6.6, tracked at
113101
# https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/issues/218).
114102
@test c * u a broken = (arrayt Array && f (right_orth, lq_compact, lq_full))
115-
if arrayt Array
116-
@test isisometric(u; side = :right)
117-
else
118-
# TODO: Fix comparison with UniformScaling on GPU.
119-
@test_broken isisometric(u; side = :right)
120-
end
103+
# TODO: Fix comparison with UniformScaling on GPU.
104+
@test isisometric(u; side = :right) broken = arrayt Array
121105
end
122106
for f in (svd_compact, svd_full, svd_trunc)
123-
if arrayt Array && f svd_trunc
124-
# `svd_trunc` on JLArray-backed `BlockSparseMatrix{T, AbstractMatrix{T}}`
125-
# currently triggers scalar indexing on the GPU array.
126-
@test begin
127-
u, s, v = f(a)
128-
u * s * v a
129-
end broken = true
130-
else
107+
# `svd_trunc` on JLArray-backed `BlockSparseMatrix{T, AbstractMatrix{T}}`
108+
# currently triggers scalar indexing on the GPU array.
109+
@test begin
131110
u, s, v = f(a)
132-
@test u * s * v a
133-
end
111+
u * s * v a
112+
end broken = (arrayt Array && f svd_trunc)
134113
end
135114
end

0 commit comments

Comments
 (0)