Skip to content

Commit ab13437

Browse files
committed
Better handling of TruncatedAlgorithm in select_algorithm
1 parent 0cf1820 commit ab13437

7 files changed

Lines changed: 75 additions & 6 deletions

File tree

src/interface/eig.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ for f in (:eig_full!, :eig_vals!)
100100
end
101101

102102
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)
103-
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
104-
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
103+
if alg isa TruncatedAlgorithm
104+
isnothing(trunc) ||
105+
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
106+
return alg
107+
else
108+
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
109+
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
110+
end
105111
end

src/interface/eigh.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,12 @@ for f in (:eigh_full!, :eigh_vals!)
101101
end
102102

103103
function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...)
104-
alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...)
105-
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
104+
if alg isa TruncatedAlgorithm
105+
isnothing(trunc) ||
106+
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
107+
return alg
108+
else
109+
alg_eig = select_algorithm(eigh_full!, A, alg; kwargs...)
110+
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
111+
end
106112
end

src/interface/svd.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!)
105105
end
106106

107107
function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs...)
108-
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
109-
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
108+
if alg isa TruncatedAlgorithm
109+
isnothing(trunc) ||
110+
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
111+
return alg
112+
else
113+
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
114+
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
115+
end
110116
end

test/algorithms.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ end
5050
NoTruncation())
5151
end
5252

53+
alg = TruncatedAlgorithm(LAPACK_Simple(), TruncationKeepBelow(0.1, 0.0))
54+
for f in (eig_trunc!, eigh_trunc!, svd_trunc!)
55+
@test @constinferred(select_algorithm(eig_trunc!, A, alg)) === alg
56+
@test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc=(; maxrank=2))
57+
end
58+
5359
@test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_DivideAndConquer()
5460
@test @constinferred(select_algorithm(svd_compact!, A, nothing)) ===
5561
LAPACK_DivideAndConquer()

test/eig.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,17 @@ end
5757
@test V2 * ((V2' * V2) \ (V2' * V1)) V1
5858
end
5959
end
60+
61+
@testset "eig_trunc! specify truncation algorithm T = $T" for T in
62+
(Float32, Float64, ComplexF32,
63+
ComplexF64)
64+
rng = StableRNG(123)
65+
m = 4
66+
V = qr_compact(randn(rng, T, m, m))[1]
67+
D = Diagonal([0.9, 0.3, 0.1, 0.01])
68+
A = V * D * V'
69+
alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2))
70+
D2, V2 = @constinferred eig_trunc(A; alg)
71+
@test diagview(D2) diagview(D)[1:2] rtol = sqrt(eps(real(T)))
72+
@test_throws ArgumentError eig_trunc(A; alg, trunc=(; maxrank=2))
73+
end

test/eigh.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,19 @@ end
6262
@test V2 * (V2' * V1) V1
6363
end
6464
end
65+
66+
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in
67+
(Float32, Float64,
68+
ComplexF32,
69+
ComplexF64)
70+
rng = StableRNG(123)
71+
m = 4
72+
V = qr_compact(randn(rng, T, m, m))[1]
73+
D = Diagonal([0.9, 0.3, 0.1, 0.01])
74+
A = V * D * V'
75+
A = (A + A') / 2
76+
alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2))
77+
D2, V2 = @constinferred eigh_trunc(A; alg)
78+
@test diagview(D2) diagview(D)[1:2] rtol = sqrt(eps(real(T)))
79+
@test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
80+
end

test/svd.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,18 @@ end
152152
end
153153
end
154154
end
155+
156+
@testset "svd_trunc! specify truncation algorithm T = $T" for T in
157+
(Float32, Float64, ComplexF32,
158+
ComplexF64)
159+
rng = StableRNG(123)
160+
m = 4
161+
U = qr_compact(randn(rng, T, m, m))[1]
162+
S = Diagonal([0.9, 0.3, 0.1, 0.01])
163+
Vᴴ = qr_compact(randn(rng, T, m, m))[1]
164+
A = U * S * Vᴴ
165+
alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), TruncationKeepAbove(0.2, 0.0))
166+
U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg)
167+
@test diagview(S2) diagview(S)[1:2] rtol = sqrt(eps(real(T)))
168+
@test_throws ArgumentError svd_trunc(A; alg, trunc=(; maxrank=2))
169+
end

0 commit comments

Comments
 (0)