Skip to content

Commit 37bc892

Browse files
authored
Testsuite for polar (#123)
* Testsuite for polar * Make left polar newton more GPU friendly * Don't test Jacobi on old Julia * Include properly
1 parent 9d1ffb8 commit 37bc892

10 files changed

Lines changed: 151 additions & 269 deletions

File tree

src/implementations/polar.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ copy_input(::typeof(right_polar), A) = copy_input(svd_full, A)
66
function check_input(::typeof(left_polar!), A::AbstractMatrix, WP, ::AbstractAlgorithm)
77
m, n = size(A)
88
W, P = WP
9-
m >= n ||
10-
throw(ArgumentError("input matrix needs at least as many rows as columns"))
9+
m n ||
10+
throw(ArgumentError("input matrix needs at least as many rows ($m) as columns ($n)"))
1111
@assert W isa AbstractMatrix && P isa AbstractMatrix
1212
@check_size(W, (m, n))
1313
@check_scalar(W, A)
@@ -18,8 +18,8 @@ end
1818
function check_input(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, ::AbstractAlgorithm)
1919
m, n = size(A)
2020
P, Wᴴ = PWᴴ
21-
n >= m ||
22-
throw(ArgumentError("input matrix needs at least as many columns as rows"))
21+
n m ||
22+
throw(ArgumentError("input matrix needs at least as many columns ($n) as rows ($m)"))
2323
@assert P isa AbstractMatrix && Wᴴ isa AbstractMatrix
2424
isempty(P) || @check_size(P, (m, m))
2525
@check_scalar(P, A)
@@ -107,19 +107,19 @@ function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol =
107107
if m > n # initial QR
108108
Q, R = qr_compact!(A)
109109
Rc = view(A, 1:n, 1:n)
110-
copy!(Rc, R)
110+
Rc .= R
111111
Rᴴinv = ldiv!(UpperTriangular(Rc)', one!(Rᴴinv))
112112
else # m == n
113113
R = A
114114
Rc = view(W, 1:n, 1:n)
115-
copy!(Rc, R)
115+
Rc .= R
116116
Rᴴinv = ldiv!(lu!(Rc)', one!(Rᴴinv))
117117
end
118118
γ = sqrt(norm(Rᴴinv) / norm(R)) # scaling factor
119119
rmul!(R, γ)
120120
rmul!(Rᴴinv, 1 / γ)
121121
R, Rᴴinv = _avgdiff!(R, Rᴴinv)
122-
copy!(Rc, R)
122+
Rc .= R
123123
i = 1
124124
conv = norm(Rᴴinv, Inf)
125125
while i < maxiter && conv > tol
@@ -128,7 +128,7 @@ function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol =
128128
rmul!(R, γ)
129129
rmul!(Rᴴinv, 1 / γ)
130130
R, Rᴴinv = _avgdiff!(R, Rᴴinv)
131-
copy!(Rc, R)
131+
Rc .= R
132132
conv = norm(Rᴴinv, Inf)
133133
i += 1
134134
end
@@ -152,7 +152,7 @@ function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); to
152152
else # m == n
153153
L = A
154154
Lc = view(Wᴴ, 1:m, 1:m)
155-
copy!(Lc, L)
155+
Lc .= L
156156
Lᴴinv = ldiv!(lu!(Lc)', one!(Lᴴinv))
157157
end
158158
γ = sqrt(norm(Lᴴinv) / norm(L)) # scaling factor
@@ -168,7 +168,7 @@ function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); to
168168
rmul!(L, γ)
169169
rmul!(Lᴴinv, 1 / γ)
170170
L, Lᴴinv = _avgdiff!(L, Lᴴinv)
171-
copy!(Lc, L)
171+
Lc .= L
172172
conv = norm(Lᴴinv, Inf)
173173
i += 1
174174
end

src/yalapack.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2162,7 +2162,7 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
21622162
jobu = 'N'
21632163
else
21642164
size(U, 1) == m ||
2165-
throw(DimensionMismatch("row size mismatch between A and U"))
2165+
throw(DimensionMismatch("row size mismatch between A ($m) and U ($(size(U, 1)))"))
21662166
size(U, 2) >= (range == 'I' ? iu - il + 1 : minmn) ||
21672167
throw(DimensionMismatch("invalid column size of U"))
21682168
jobu = 'V'
@@ -2171,13 +2171,13 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
21712171
jobvt = 'N'
21722172
else
21732173
size(Vᴴ, 2) == n ||
2174-
throw(DimensionMismatch("column size mismatch between A and Vᴴ"))
2174+
throw(DimensionMismatch("column size mismatch between A ($n) and Vᴴ ($(size(Vᴴ, 2)))"))
21752175
size(Vᴴ, 1) >= (range == 'I' ? iu - il + 1 : minmn) ||
21762176
throw(DimensionMismatch("invalid row size of Vᴴ"))
21772177
jobvt = 'V'
21782178
end
21792179
length(S) == minmn ||
2180-
throw(DimensionMismatch("length mismatch between A and S"))
2180+
throw(DimensionMismatch("length mismatch between A ($minmn) and S ($(length(S)))"))
21812181

21822182
lda = max(1, stride(A, 2))
21832183
ldu = max(1, stride(U, 2))
@@ -2247,15 +2247,15 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
22472247
require_one_based_indexing(A, U, Vᴴ, S)
22482248
chkstride1(A, U, Vᴴ, S)
22492249
m, n = size(A)
2250-
m >= n ||
2251-
throw(ArgumentError("gejsv! requires a matrix with at least as many rows as columns"))
2250+
m n ||
2251+
throw(ArgumentError("gejsv! requires a matrix with at least as many rows ($m) as columns ($n)"))
22522252

22532253
joba = 'G'
22542254
if length(U) == 0
22552255
jobu = 'N'
22562256
else
22572257
size(U, 1) == m ||
2258-
throw(DimensionMismatch("row size mismatch between A and U"))
2258+
throw(DimensionMismatch("row size mismatch between A ($m) and U ($(size(U, 1)))"))
22592259
if size(U, 2) == n
22602260
jobu = 'U'
22612261
elseif size(U, 2) == m
@@ -2268,15 +2268,15 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
22682268
jobv = 'N'
22692269
else
22702270
size(Vᴴ, 2) == n ||
2271-
throw(DimensionMismatch("column size mismatch between A and Vᴴ"))
2271+
throw(DimensionMismatch("column size mismatch between A ($n) and Vᴴ ($(size(Vᴴ, 2)))"))
22722272
if size(Vᴴ, 1) == n
22732273
jobv = 'V'
22742274
else
22752275
throw(DimensionMismatch("invalid row size of Vᴴ"))
22762276
end
22772277
end
22782278
length(S) == n ||
2279-
throw(DimensionMismatch("length mismatch between A and S"))
2279+
throw(DimensionMismatch("length mismatch between A ($minmn) and S ($(length(S)))"))
22802280

22812281
lda = max(1, stride(A, 2))
22822282
mv = Ref{BlasInt}() # unused

test/amd/polar.jl

Lines changed: 0 additions & 83 deletions
This file was deleted.

test/cuda/polar.jl

Lines changed: 0 additions & 83 deletions
This file was deleted.

test/lq.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63)
3535
TestSuite.test_lq_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),))
3636
end
3737
end
38-
elseif !is_buildkite
38+
end
39+
if !is_buildkite
3940
if T BLASFloats
4041
TestSuite.test_lq(T, (m, n))
4142
LAPACK_LQ_ALGS = (

0 commit comments

Comments
 (0)