Skip to content

Commit aff3314

Browse files
committed
More tests for CUDA and AMD
1 parent 0bac889 commit aff3314

7 files changed

Lines changed: 747 additions & 0 deletions

File tree

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ Zygote = "0.7"
3232
julia = "1.10"
3333

3434
[extras]
35+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
3536
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3637
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
38+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3739
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
3840
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3941
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -43,3 +45,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4345

4446
[targets]
4547
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU"]
48+
49+
[sources]
50+
CUDA = {url = "https://github.com/JuliaGPU/CUDA.jl", rev = "master"}
51+
AMDGPU = {url = "https://github.com/JuliaGPU/AMDGPU.jl", rev = "master"}

src/common/matrixproperties.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ See also [`isisometry`](@ref) and [`is_right_isometry`](@ref).
4242
""" is_left_isometry
4343

4444
function is_left_isometry(A::AbstractMatrix; isapprox_kwargs...)
45+
iszero(min(size(A)...)) && return true
4546
return isapprox(A' * A, LinearAlgebra.I; isapprox_kwargs...)
4647
end
4748

@@ -55,5 +56,6 @@ See also [`isisometry`](@ref) and [`is_left_isometry`](@ref).
5556
""" is_right_isometry
5657

5758
function is_right_isometry(A::AbstractMatrix; isapprox_kwargs...)
59+
iszero(min(size(A)...)) && return true
5860
return isapprox(A * A', LinearAlgebra.I; isapprox_kwargs...)
5961
end

test/amd/orthnull.jl

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, I, mul!
6+
using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow
7+
using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_algorithm,
8+
initialize_output, AbstractAlgorithm
9+
using AMDGPU
10+
11+
# Used to test non-AbstractMatrix codepaths.
12+
struct LinearMap{P<:AbstractMatrix}
13+
parent::P
14+
end
15+
Base.parent(A::LinearMap) = getfield(A, :parent)
16+
function Base.copy!(dest::LinearMap, src::LinearMap)
17+
copy!(parent(dest), parent(src))
18+
return dest
19+
end
20+
function LinearAlgebra.mul!(C::LinearMap, A::LinearMap, B::LinearMap)
21+
mul!(parent(C), parent(A), parent(B))
22+
return C
23+
end
24+
25+
function MatrixAlgebraKit.copy_input(::typeof(qr_compact), A::LinearMap)
26+
return LinearMap(copy_input(qr_compact, parent(A)))
27+
end
28+
function MatrixAlgebraKit.copy_input(::typeof(lq_compact), A::LinearMap)
29+
return LinearMap(copy_input(lq_compact, parent(A)))
30+
end
31+
function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), A::LinearMap)
32+
return LinearMap.(initialize_output(left_orth!, parent(A)))
33+
end
34+
function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap)
35+
return LinearMap.(initialize_output(right_orth!, parent(A)))
36+
end
37+
function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC, alg::AbstractAlgorithm)
38+
return check_input(left_orth!, parent(A), parent.(VC), alg)
39+
end
40+
function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC, alg::AbstractAlgorithm)
41+
return check_input(right_orth!, parent(A), parent.(VC), alg)
42+
end
43+
function MatrixAlgebraKit.default_svd_algorithm(::Type{LinearMap{A}}; kwargs...) where {A}
44+
return default_svd_algorithm(A; kwargs...)
45+
end
46+
function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), A::LinearMap,
47+
alg::GPU_SVDAlgorithm)
48+
return LinearMap.(initialize_output(svd_compact!, parent(A), alg))
49+
end
50+
function MatrixAlgebraKit.svd_compact!(A::LinearMap, USVᴴ, alg::GPU_SVDAlgorithm)
51+
return LinearMap.(svd_compact!(parent(A), parent.(USVᴴ), alg))
52+
end
53+
54+
@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
55+
rng = StableRNG(123)
56+
m = 54
57+
for n in (37, m, 63)
58+
minmn = min(m, n)
59+
A = ROCArray(randn(rng, T, m, n))
60+
V, C = @constinferred left_orth(A)
61+
N = @constinferred left_null(A)
62+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
63+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
64+
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
65+
@test V * C A
66+
@test isisometry(V)
67+
@test LinearAlgebra.norm(A' * N) 0 atol = MatrixAlgebraKit.defaulttol(T)
68+
@test isisometry(N)
69+
@test V * V' + N * N' I
70+
71+
M = LinearMap(A)
72+
VM, CM = @constinferred left_orth(M; kind=:svd)
73+
@test parent(VM) * parent(CM) A
74+
75+
if m > n
76+
nullity = 5
77+
V, C = @constinferred left_orth(A)
78+
# doesn't work because of truncation
79+
#N = @constinferred left_null(A; trunc=(; maxnullity=nullity))
80+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
81+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
82+
#@test N isa ROCMatrix{T} && size(N) == (m, nullity)
83+
@test V * C A
84+
@test isisometry(V)
85+
#@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
86+
#@test isisometry(N)
87+
end
88+
89+
for alg_qr in ((; positive=true), (; positive=false), rocSOLVER_HouseholderQR())
90+
V, C = @constinferred left_orth(A; alg_qr)
91+
N = @constinferred left_null(A; alg_qr)
92+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
93+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
94+
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
95+
@test V * C A
96+
@test isisometry(V)
97+
@test LinearAlgebra.norm(A' * N) 0 atol = MatrixAlgebraKit.defaulttol(T)
98+
@test isisometry(N)
99+
@test V * V' + N * N' I
100+
end
101+
102+
Ac = similar(A)
103+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C))
104+
N2 = @constinferred left_null!(copy!(Ac, A), N)
105+
@test V2 === V
106+
@test C2 === C
107+
@test N2 === N
108+
@test V2 * C2 A
109+
@test isisometry(V2)
110+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
111+
@test isisometry(N2)
112+
@test V2 * V2' + N2 * N2' I
113+
114+
atol = eps(real(T))
115+
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; atol=atol))
116+
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=(; atol=atol))
117+
#@test V2 !== V
118+
#@test C2 !== C
119+
@test N2 !== C
120+
#@test V2 * C2 ≈ A
121+
#@test isisometry(V2)
122+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
123+
@test isisometry(N2)
124+
#@test V2 * V2' + N2 * N2' ≈ I
125+
126+
rtol = eps(real(T))
127+
for (trunc_orth, trunc_null) in (((; rtol=rtol), (; rtol=rtol)),
128+
(TruncationKeepAbove(0, rtol), TruncationKeepBelow(0, rtol)))
129+
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth)
130+
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=trunc_null)
131+
#@test V2 !== V
132+
#@test C2 !== C
133+
@test N2 !== C
134+
#@test V2 * C2 ≈ A
135+
#@test isisometry(V2)
136+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
137+
@test isisometry(N2)
138+
#@test V2 * V2' + N2 * N2' ≈ I
139+
end
140+
141+
for kind in (:qr, :polar, :svd) # explicit kind kwarg
142+
m < n && kind == :polar && continue
143+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind)
144+
@test V2 === V
145+
@test C2 === C
146+
@test V2 * C2 A
147+
@test isisometry(V2)
148+
if kind != :polar
149+
N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind)
150+
@test N2 === N
151+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
152+
@test isisometry(N2)
153+
@test V2 * V2' + N2 * N2' I
154+
end
155+
156+
# with kind and tol kwargs
157+
if kind == :svd
158+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
159+
trunc=(; atol=atol))
160+
N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind,
161+
trunc=(; atol=atol))
162+
@test V2 !== V
163+
@test C2 !== C
164+
@test N2 !== C
165+
@test V2 * C2 A
166+
@test V2' * V2 I
167+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
168+
@test N2' * N2 I
169+
@test V2 * V2' + N2 * N2' I
170+
171+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
172+
trunc=(; rtol=rtol))
173+
N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind,
174+
trunc=(; rtol=rtol))
175+
@test V2 !== V
176+
@test C2 !== C
177+
@test N2 !== C
178+
@test V2 * C2 A
179+
@test isisometry(V2)
180+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
181+
@test isisometry(N2)
182+
@test V2 * V2' + N2 * N2' I
183+
else
184+
@test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind=kind,
185+
trunc=(; atol=atol))
186+
@test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind=kind,
187+
trunc=(; rtol=rtol))
188+
@test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind,
189+
trunc=(; atol=atol))
190+
@test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind,
191+
trunc=(; rtol=rtol))
192+
end
193+
end
194+
end
195+
end
196+
197+
@testset "right_orth and right_null for T = $T" for T in (Float32, Float64, ComplexF32,
198+
ComplexF64)
199+
rng = StableRNG(123)
200+
m = 54
201+
@testset for n in (37, m, 63)
202+
minmn = min(m, n)
203+
A = ROCArray(randn(rng, T, m, n))
204+
C, Vᴴ = @constinferred right_orth(A)
205+
Nᴴ = @constinferred right_null(A)
206+
@test C isa ROCMatrix{T} && size(C) == (m, minmn)
207+
@test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (minmn, n)
208+
@test Nᴴ isa ROCMatrix{T} && size(Nᴴ) == (n - minmn, n)
209+
@test C * Vᴴ A
210+
@test isisometry(Vᴴ; side=:right)
211+
@test LinearAlgebra.norm(A * adjoint(Nᴴ)) 0 atol = MatrixAlgebraKit.defaulttol(T)
212+
@test isisometry(Nᴴ; side=:right)
213+
@test Vᴴ' * Vᴴ + Nᴴ' * Nᴴ I
214+
215+
M = LinearMap(A)
216+
CM, VMᴴ = @constinferred right_orth(M; kind=:svd)
217+
@test parent(CM) * parent(VMᴴ) A
218+
219+
Ac = similar(A)
220+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ))
221+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ)
222+
@test C2 === C
223+
@test Vᴴ2 === Vᴴ
224+
@test Nᴴ2 === Nᴴ
225+
@test C2 * Vᴴ2 A
226+
@test isisometry(Vᴴ2; side=:right)
227+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
228+
@test isisometry(Nᴴ; side=:right)
229+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 I atol = MatrixAlgebraKit.defaulttol(T)
230+
231+
# TODO truncate currently broken due to searchsortedlast
232+
atol = eps(real(T))
233+
rtol = eps(real(T))
234+
#=C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; atol=atol))
235+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; atol=atol))
236+
@test C2 !== C
237+
@test Vᴴ2 !== Vᴴ
238+
@test Nᴴ2 !== Nᴴ
239+
@test C2 * Vᴴ2 ≈ A
240+
@test isisometry(Vᴴ2; side=:right)
241+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
242+
@test isisometry(Nᴴ; side=:right)
243+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
244+
245+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; rtol=rtol))
246+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; rtol=rtol))
247+
@test C2 !== C
248+
@test Vᴴ2 !== Vᴴ
249+
@test Nᴴ2 !== Nᴴ
250+
@test C2 * Vᴴ2 ≈ A
251+
@test isisometry(Vᴴ2; side=:right)
252+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
253+
@test isisometry(Nᴴ2; side=:right)
254+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
255+
=#
256+
257+
@testset "kind = $kind" for kind in (:lq, :polar, :svd)
258+
n < m && kind == :polar && continue
259+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind)
260+
@test C2 === C
261+
@test Vᴴ2 === Vᴴ
262+
A2 = C2 * Vᴴ2
263+
@test A2 A
264+
@test isisometry(Vᴴ2; side=:right)
265+
if kind != :polar
266+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind)
267+
@test Nᴴ2 === Nᴴ
268+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
269+
@test isisometry(Nᴴ2; side=:right)
270+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 I
271+
end
272+
273+
if kind == :svd
274+
# doesn't work yet because of searchsortedfirst
275+
#= C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
276+
trunc=(; atol=atol))
277+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind,
278+
trunc=(; atol=atol))
279+
@test C2 !== C
280+
@test Vᴴ2 !== Vᴴ
281+
@test Nᴴ2 !== Nᴴ
282+
@test C2 * Vᴴ2 ≈ A
283+
@test isisometry(Vᴴ2; side=:right)
284+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
285+
@test isisometry(Nᴴ2; side=:right)
286+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
287+
288+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
289+
trunc=(; rtol=rtol))
290+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind,
291+
trunc=(; rtol=rtol))
292+
@test C2 !== C
293+
@test Vᴴ2 !== Vᴴ
294+
@test Nᴴ2 !== Nᴴ
295+
@test C2 * Vᴴ2 ≈ A
296+
@test isisometry(Vᴴ2; side=:right)
297+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
298+
@test isisometry(Nᴴ2; side=:right)
299+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
300+
=#
301+
else
302+
@test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
303+
trunc=(; atol=atol))
304+
@test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
305+
trunc=(; rtol=rtol))
306+
@test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind=kind,
307+
trunc=(; atol=atol))
308+
@test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind=kind,
309+
trunc=(; rtol=rtol))
310+
end
311+
end
312+
end
313+
end

0 commit comments

Comments
 (0)