Skip to content

Commit 5fc5ce5

Browse files
authored
some yalapack fixes and test tryouts (#135)
1 parent b6eaf3d commit 5fc5ce5

2 files changed

Lines changed: 76 additions & 74 deletions

File tree

src/yalapack.jl

Lines changed: 74 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ module YALAPACK # Yet another lapack wrapper
1010

1111
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, Char, LAPACK,
1212
LAPACKException, SingularException, PosDefException, checksquare, chkstride1,
13-
require_one_based_indexing, triu!, isposdef, adjoint!
13+
require_one_based_indexing, triu!, isposdef, adjoint!, rmul!
1414

1515
using LinearAlgebra.BLAS: @blasfunc, libblastrampoline
1616
using LinearAlgebra.LAPACK: chkfinite, chktrans, chkside, chkuplofinite, chklapackerror
@@ -20,66 +20,66 @@ const BlasMat{T <: BlasFloat} = StridedMatrix{T}
2020
# type alias for matrices that are possibly supported by YALAPACK, after conversion
2121
const MaybeBlasMat = Union{BlasMat, AbstractMatrix{<:Integer}}
2222

23-
# LU factorisation
24-
for (getrf, getrs, elty) in (
25-
(:dgetrf_, :dgetrs_, :Float64),
26-
(:sgetrf_, :sgetrs_, :Float32),
27-
(:zgetrf_, :zgetrs_, :ComplexF64),
28-
(:cgetrf_, :cgetrs_, :ComplexF32),
29-
)
30-
@eval begin
31-
function getrf!(
32-
A::AbstractMatrix{$elty}, ipiv::AbstractVector{BlasInt};
33-
check::Bool = true
34-
)
35-
require_one_based_indexing(A, ipiv)
36-
chkstride1(A, ipiv)
37-
chkfinite(A)
38-
m, n = size(A)
23+
# LU factorisation (currently unused in MatrixAlgebraKit)
24+
# for (getrf, getrs, elty) in (
25+
# (:dgetrf_, :dgetrs_, :Float64),
26+
# (:sgetrf_, :sgetrs_, :Float32),
27+
# (:zgetrf_, :zgetrs_, :ComplexF64),
28+
# (:cgetrf_, :cgetrs_, :ComplexF32),
29+
# )
30+
# @eval begin
31+
# function getrf!(
32+
# A::AbstractMatrix{$elty}, ipiv::AbstractVector{BlasInt};
33+
# check::Bool = true
34+
# )
35+
# require_one_based_indexing(A, ipiv)
36+
# chkstride1(A, ipiv)
37+
# chkfinite(A)
38+
# m, n = size(A)
3939

40-
lda = max(1, stride(A, 2))
41-
info = Ref{BlasInt}()
42-
ccall(
43-
(@blasfunc($getrf), libblastrampoline), Cvoid,
44-
(
45-
Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty},
46-
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
47-
),
48-
m, n, A, lda, ipiv, info
49-
)
50-
chkargsok(info[])
51-
return A, ipiv, info[] #Error code is stored in LU factorization type
52-
end
53-
function getrs!(
54-
trans::AbstractChar, A::AbstractMatrix{$elty},
55-
ipiv::AbstractVector{BlasInt}, B::AbstractVecOrMat{$elty}
56-
)
57-
require_one_based_indexing(A, ipiv, B)
58-
chktrans(trans)
59-
chkstride1(A, B, ipiv)
60-
n = checksquare(A)
61-
if n != size(B, 1)
62-
throw(DimensionMismatch(lazy"B has leading dimension $(size(B,1)), but needs $n"))
63-
end
64-
if n != length(ipiv)
65-
throw(DimensionMismatch(lazy"ipiv has length $(length(ipiv)), but needs to be $n"))
66-
end
67-
nrhs = size(B, 2)
68-
info = Ref{BlasInt}()
69-
ccall(
70-
(@blasfunc($getrs), libblastrampoline), Cvoid,
71-
(
72-
Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
73-
Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, Clong,
74-
),
75-
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B,
76-
max(1, stride(B, 2)), info, 1
77-
)
78-
chklapackerror(info[])
79-
return B
80-
end
81-
end
82-
end
40+
# lda = max(1, stride(A, 2))
41+
# info = Ref{BlasInt}()
42+
# ccall(
43+
# (@blasfunc($getrf), libblastrampoline), Cvoid,
44+
# (
45+
# Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty},
46+
# Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
47+
# ),
48+
# m, n, A, lda, ipiv, info
49+
# )
50+
# chkargsok(info[])
51+
# return A, ipiv, info[] #Error code is stored in LU factorization type
52+
# end
53+
# function getrs!(
54+
# trans::AbstractChar, A::AbstractMatrix{$elty},
55+
# ipiv::AbstractVector{BlasInt}, B::AbstractVecOrMat{$elty}
56+
# )
57+
# require_one_based_indexing(A, ipiv, B)
58+
# chktrans(trans)
59+
# chkstride1(A, B, ipiv)
60+
# n = checksquare(A)
61+
# if n != size(B, 1)
62+
# throw(DimensionMismatch(lazy"B has leading dimension $(size(B,1)), but needs $n"))
63+
# end
64+
# if n != length(ipiv)
65+
# throw(DimensionMismatch(lazy"ipiv has length $(length(ipiv)), but needs to be $n"))
66+
# end
67+
# nrhs = size(B, 2)
68+
# info = Ref{BlasInt}()
69+
# ccall(
70+
# (@blasfunc($getrs), libblastrampoline), Cvoid,
71+
# (
72+
# Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
73+
# Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, Clong,
74+
# ),
75+
# trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B,
76+
# max(1, stride(B, 2)), info, 1
77+
# )
78+
# chklapackerror(info[])
79+
# return B
80+
# end
81+
# end
82+
# end
8383

8484
# LQ, RQ, QL, and QR factorisation
8585
const DEFAULT_QR_BLOCKSIZE = 36
@@ -451,16 +451,16 @@ for (gemqr, gemlq, ungqr, unglq, ungql, ungrq, unmqr, unmlq, unmql, unmrq, gemqr
451451
k = min(mA, nA)
452452

453453
if side == 'L' && mC != mA
454-
throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $m, must equal the first dimension of A, $mA"))
454+
throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $mC, must equal the first dimension of A, $mA"))
455455
end
456456
if side == 'R' && nC != mA
457-
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $n, must equal the first dimension of A, $mA"))
457+
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $nC, must equal the first dimension of A, $mA"))
458458
end
459459
if side == 'L' && k > mC
460-
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $m"))
460+
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $mC"))
461461
end
462462
if side == 'R' && k > nC
463-
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= n = $n"))
463+
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= n = $nC"))
464464
end
465465
lda = max(1, stride(A, 2))
466466
ldc = max(1, stride(C, 2))
@@ -503,16 +503,16 @@ for (gemqr, gemlq, ungqr, unglq, ungql, ungrq, unmqr, unmlq, unmql, unmrq, gemqr
503503
k = min(mA, nA)
504504

505505
if side == 'L' && mC != nA
506-
throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $m, must equal the second dimension of A, $nA"))
506+
throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $mC, must equal the second dimension of A, $nA"))
507507
end
508508
if side == 'R' && nC != nA
509-
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $n, must equal the second dimension of A, $nA"))
509+
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $nC, must equal the second dimension of A, $nA"))
510510
end
511511
if side == 'L' && k > mC
512-
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $m"))
512+
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $mC"))
513513
end
514514
if side == 'R' && k > nC
515-
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= n = $n"))
515+
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= n = $nC"))
516516
end
517517
lda = max(1, stride(A, 2))
518518
ldc = max(1, stride(C, 2))
@@ -1170,6 +1170,7 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in
11701170
n = checksquare(A)
11711171
chkuplofinite(A, uplo)
11721172
if haskey(kwargs, :irange)
1173+
irange = convert(UnitRange{Int}, kwargs[:irange])
11731174
il = first(irange)
11741175
iu = last(irange)
11751176
vl = vu = zero($relty)
@@ -2143,6 +2144,7 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
21432144
m, n = size(A)
21442145
minmn = min(m, n)
21452146
if haskey(kwargs, :irange)
2147+
irange = convert(UnitRange{Int}, kwargs[:irange])
21462148
il = first(irange)
21472149
iu = last(irange)
21482150
vl = vu = zero($relty)
@@ -2276,15 +2278,15 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
22762278
end
22772279
end
22782280
length(S) == n ||
2279-
throw(DimensionMismatch("length mismatch between A ($minmn) and S ($(length(S)))"))
2281+
throw(DimensionMismatch("length mismatch between A ($n) and S ($(length(S)))"))
22802282

22812283
lda = max(1, stride(A, 2))
22822284
mv = Ref{BlasInt}() # unused
22832285
if jobv == 'V'
22842286
if U !== A
22852287
V = view(U, 1:n, 1:n) # use U as V storage
22862288
else
2287-
V = view(similar(V), 1:n, 1:n)
2289+
V = view(similar(Vᴴ), 1:n, 1:n)
22882290
end
22892291
else
22902292
V = Vᴴ # doesn't matter, V is not used
@@ -2342,12 +2344,12 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
23422344
if cmplx
23432345
if !isone(rwork[1])
23442346
@warn "singular values might have underflowed or overflowed"
2345-
LinearAlgebra.rmul!(S, rwork[1])
2347+
rmul!(S, rwork[1])
23462348
end
23472349
else
23482350
if !isone(work[1])
23492351
@warn "singular values might have underflowed or overflowed"
2350-
LinearAlgebra.rmul!(S, work[1])
2352+
rmul!(S, work[1])
23512353
end
23522354
end
23532355
if jobu == 'U' && U !== A

test/ad_utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ function stabilize_eigvals!(D::AbstractVector)
3838
end
3939
n = maximum(p)
4040
# rescale eigenvalues so that they lie on distinct radii in the complex plane
41-
# that are chosen randomly in non-overlapping intervals [k/n, (k+0.5)/n)] for k=1,...,n
42-
radii = ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n
41+
# that are chosen randomly in non-overlapping intervals [10 * k/n, 10 * (k+0.5)/n)] for k=1,...,n
42+
radii = 10 .* ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n
4343
for i in 1:length(D)
4444
D[i] = sign(D[i]) * radii[p[i]]
4545
end

0 commit comments

Comments
 (0)