@@ -8,7 +8,7 @@ using MatrixAlgebraKit: eig_full, eig_trunc, eig_vals, eigh_full, eigh_trunc, ei
88 right_orth, right_polar, svd_compact, svd_full, svd_trunc
99using SparseArraysBase: storedlength
1010using StableRNGs: StableRNG
11- using Test: @test , @test_broken , @ testset
11+ using Test: @test , @testset
1212
1313elts = (Float32, Float64, ComplexF32)
1414arrayts = (Array, JLArray)
@@ -60,50 +60,38 @@ arrayts = (Array, JLArray)
6060 a = BlockSparseMatrix {elt, AbstractMatrix{elt}} (undef, [2 , 3 ], [2 , 3 ])
6161 a[Block (1 , 1 )] = dev (randn (rng, elt, 2 , 2 ))
6262 for f in (eig_full, eig_trunc)
63- if arrayt === Array
63+ @test begin
6464 d, v = f (a)
65- @test a * v ≈ v * d
66- else
67- @test_broken f (a)
68- end
65+ a * v ≈ v * d
66+ end broken = arrayt ≢ Array
6967 end
70- if arrayt === Array
68+ @test begin
7169 d = eig_vals (a)
72- @test sort (Vector (d); by = abs) ≈ sort (eig_vals (Matrix (a)); by = abs)
73- else
74- @test_broken eig_vals (a)
75- end
70+ sort (Vector (d); by = abs) ≈ sort (eig_vals (Matrix (a)); by = abs)
71+ end broken = arrayt ≢ Array
7672
7773 rng = StableRNG (1234 )
7874 a = BlockSparseMatrix {elt, AbstractMatrix{elt}} (undef, [2 , 3 ], [2 , 3 ])
7975 a[Block (1 , 1 )] = dev (parent (hermitianpart (randn (rng, elt, 2 , 2 ))))
8076 for f in (eigh_full, eigh_trunc)
81- if arrayt === Array
77+ @test begin
8278 d, v = f (a)
83- @test a * v ≈ v * d
84- else
85- @test_broken f (a)
86- end
79+ a * v ≈ v * d
80+ end broken = arrayt ≢ Array
8781 end
88- if arrayt === Array
82+ @test begin
8983 d = eigh_vals (a)
90- @test sort (Vector (d); by = abs) ≈ sort (eig_vals (Matrix (a)); by = abs)
91- else
92- @test_broken eigh_vals (a)
93- end
84+ sort (Vector (d); by = abs) ≈ sort (eig_vals (Matrix (a)); by = abs)
85+ end broken = arrayt ≢ Array
9486
9587 rng = StableRNG (1234 )
9688 a = BlockSparseMatrix {elt, AbstractMatrix{elt}} (undef, [2 , 3 ], [2 , 3 ])
9789 a[Block (1 , 1 )] = dev (randn (rng, elt, 2 , 2 ))
9890 for f in (left_orth, left_polar, qr_compact, qr_full)
9991 u, c = f (a)
10092 @test u * c ≈ a
101- if arrayt ≡ Array
102- @test isisometric (u; side = :left )
103- else
104- # TODO : Fix comparison with UniformScaling on GPU.
105- @test_broken isisometric (u; side = :left )
106- end
93+ # TODO : Fix comparison with UniformScaling on GPU.
94+ @test isisometric (u; side = :left ) broken = arrayt ≢ Array
10795 end
10896 for f in (right_orth, right_polar, lq_compact, lq_full)
10997 c, u = f (a)
@@ -112,24 +100,15 @@ arrayts = (Array, JLArray)
112100 # reproduce `a` (regression in MatrixAlgebraKit 0.6.6, tracked at
113101 # https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/issues/218).
114102 @test c * u ≈ a broken = (arrayt ≢ Array && f ∈ (right_orth, lq_compact, lq_full))
115- if arrayt ≡ Array
116- @test isisometric (u; side = :right )
117- else
118- # TODO : Fix comparison with UniformScaling on GPU.
119- @test_broken isisometric (u; side = :right )
120- end
103+ # TODO : Fix comparison with UniformScaling on GPU.
104+ @test isisometric (u; side = :right ) broken = arrayt ≢ Array
121105 end
122106 for f in (svd_compact, svd_full, svd_trunc)
123- if arrayt ≢ Array && f ≡ svd_trunc
124- # `svd_trunc` on JLArray-backed `BlockSparseMatrix{T, AbstractMatrix{T}}`
125- # currently triggers scalar indexing on the GPU array.
126- @test begin
127- u, s, v = f (a)
128- u * s * v ≈ a
129- end broken = true
130- else
107+ # `svd_trunc` on JLArray-backed `BlockSparseMatrix{T, AbstractMatrix{T}}`
108+ # currently triggers scalar indexing on the GPU array.
109+ @test begin
131110 u, s, v = f (a)
132- @test u * s * v ≈ a
133- end
111+ u * s * v ≈ a
112+ end broken = (arrayt ≢ Array && f ≡ svd_trunc)
134113 end
135114end
0 commit comments