Skip to content

Commit ed8d2bd

Browse files
committed
More support for truncation
1 parent 880e8d5 commit ed8d2bd

2 files changed

Lines changed: 69 additions & 2 deletions

File tree

ext/MatrixAlgebraKitFillArraysExt.jl

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,27 @@ for f in [:eig_full!, :eigh_full!]
6060
end
6161
end
6262

63+
for f in [:eig_trunc!, :eigh_trunc!]
64+
@eval begin
65+
# TODO: Delete this when `select_algorithm` is generalized.
66+
function MatrixAlgebraKit.select_algorithm(::typeof($f), ::Type{A}, alg;
67+
trunc=nothing,
68+
kwargs...) where {A<:Zeros}
69+
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
70+
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
71+
end
72+
# TODO: I think it would be better to dispatch on the algorithm here,
73+
# rather than the output types.
74+
function MatrixAlgebraKit.truncate!(::typeof($f), (D, V)::Tuple{Zeros,Eye},
75+
strategy::TruncationStrategy)
76+
ind = findtruncated(diagview(D), strategy)
77+
D′ = D[ind, ind]
78+
V′ = Eye((axes(V, 1), only(axes(axes(V, 2)[ind]))))
79+
return (D′, V′)
80+
end
81+
end
82+
end
83+
6384
for f in [:eig_vals!, :eigh_vals!]
6485
@eval begin
6586
function MatrixAlgebraKit.check_input(::typeof($f), A::AbstractZerosMatrix, F)
@@ -127,6 +148,24 @@ function MatrixAlgebraKit.svd_full!(A::AbstractZerosMatrix, F, alg::ZerosAlgorit
127148
return (Eye((ax[1], ax[1])), Zeros(ax), Eye((ax[2], ax[2])))
128149
end
129150

151+
# TODO: Delete this when `select_algorithm` is generalized.
152+
function MatrixAlgebraKit.select_algorithm(::typeof(svd_trunc!), ::Type{A}, alg;
153+
trunc=nothing,
154+
kwargs...) where {A<:Zeros}
155+
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
156+
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
157+
end
158+
# TODO: I think it would be better to dispatch on the algorithm here,
159+
# rather than the output types.
160+
function MatrixAlgebraKit.truncate!(::typeof(svd_trunc!), (U, S, V)::Tuple{Eye,Zeros,Eye},
161+
strategy::TruncationStrategy)
162+
ind = findtruncated(diagview(S), strategy)
163+
U′ = Eye((axes(U, 1), only(axes(axes(U, 2)[ind]))))
164+
S′ = S[ind, ind]
165+
V′ = Eye((only(axes(axes(V, 1)[ind])), axes(V, 2)))
166+
return (U′, S′, V′)
167+
end
168+
130169
function MatrixAlgebraKit.svd_vals!(A::AbstractZerosMatrix, F, alg::ZerosAlgorithm)
131170
return diagview(A)
132171
end
@@ -193,8 +232,9 @@ for f in [:eig_trunc!, :eigh_trunc!]
193232
function MatrixAlgebraKit.truncate!(::typeof($f), (D, V)::Tuple{Eye,Eye},
194233
strategy::TruncationStrategy)
195234
ind = findtruncated(diagview(D), strategy)
196-
return Diagonal(diagview(D)[ind]),
197-
Eye((axes(V, 1), only(axes(axes(V, 2)[ind]))))
235+
D′ = Diagonal(diagview(D)[ind])
236+
U′ = Eye((axes(V, 1), only(axes(axes(V, 2)[ind]))))
237+
return (D′, U′)
198238
end
199239
end
200240
end

test/fillarrays.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,20 @@ using FillArrays: SquareEye
2020
end
2121
end
2222

23+
for f in [:eig_trunc, :eigh_trunc]
24+
@eval begin
25+
A = Zeros(3, 3)
26+
D, V = @constinferred $f(A; trunc=(; maxrank=2))
27+
@test A * V == V * D
28+
@test size(D) == (2, 2)
29+
@test size(V) == (3, 2)
30+
@test D == Zeros(2, 2)
31+
@test D isa Zeros
32+
@test V == Eye(3, 2)
33+
@test V isa Eye
34+
end
35+
end
36+
2337
for f in [:eig_vals, :eigh_vals]
2438
@eval begin
2539
A = Zeros(3, 3)
@@ -131,6 +145,19 @@ using FillArrays: SquareEye
131145
@test V == I
132146
@test V isa Eye
133147

148+
A = Zeros(3, 4)
149+
U, S, V = @constinferred svd_trunc(A; trunc=(; maxrank=2))
150+
@test U * S * V == Eye(3, 2) * Zeros(2, 2) * Eye(2, 4)
151+
@test size(U) == (3, 2)
152+
@test size(S) == (2, 2)
153+
@test size(V) == (2, 4)
154+
@test S == Zeros(2, 2)
155+
@test S isa Zeros
156+
@test U == Eye(3, 2)
157+
@test U isa Eye
158+
@test V == Eye(2, 4)
159+
@test V isa Eye
160+
134161
A = Zeros(3, 4)
135162
D = @constinferred svd_vals(A)
136163
@test size(D) == (minimum(size(A)),)

0 commit comments

Comments
 (0)