@@ -60,6 +60,27 @@ for f in [:eig_full!, :eigh_full!]
6060 end
6161end
6262
63+ for f in [:eig_trunc! , :eigh_trunc! ]
64+ @eval begin
65+ # TODO : Delete this when `select_algorithm` is generalized.
66+ function MatrixAlgebraKit. select_algorithm (:: typeof ($ f), :: Type{A} , alg;
67+ trunc= nothing ,
68+ kwargs... ) where {A<: Zeros }
69+ alg_eig = select_algorithm (eig_full!, A, alg; kwargs... )
70+ return TruncatedAlgorithm (alg_eig, select_truncation (trunc))
71+ end
72+ # TODO : I think it would be better to dispatch on the algorithm here,
73+ # rather than the output types.
74+ function MatrixAlgebraKit. truncate! (:: typeof ($ f), (D, V):: Tuple{Zeros,Eye} ,
75+ strategy:: TruncationStrategy )
76+ ind = findtruncated (diagview (D), strategy)
77+ D′ = D[ind, ind]
78+ V′ = Eye ((axes (V, 1 ), only (axes (axes (V, 2 )[ind]))))
79+ return (D′, V′)
80+ end
81+ end
82+ end
83+
6384for f in [:eig_vals! , :eigh_vals! ]
6485 @eval begin
6586 function MatrixAlgebraKit. check_input (:: typeof ($ f), A:: AbstractZerosMatrix , F)
@@ -127,6 +148,24 @@ function MatrixAlgebraKit.svd_full!(A::AbstractZerosMatrix, F, alg::ZerosAlgorit
127148 return (Eye ((ax[1 ], ax[1 ])), Zeros (ax), Eye ((ax[2 ], ax[2 ])))
128149end
129150
151+ # TODO : Delete this when `select_algorithm` is generalized.
152+ function MatrixAlgebraKit. select_algorithm (:: typeof (svd_trunc!), :: Type{A} , alg;
153+ trunc= nothing ,
154+ kwargs... ) where {A<: Zeros }
155+ alg_eig = select_algorithm (eig_full!, A, alg; kwargs... )
156+ return TruncatedAlgorithm (alg_eig, select_truncation (trunc))
157+ end
158+ # TODO : I think it would be better to dispatch on the algorithm here,
159+ # rather than the output types.
160+ function MatrixAlgebraKit. truncate! (:: typeof (svd_trunc!), (U, S, V):: Tuple{Eye,Zeros,Eye} ,
161+ strategy:: TruncationStrategy )
162+ ind = findtruncated (diagview (S), strategy)
163+ U′ = Eye ((axes (U, 1 ), only (axes (axes (U, 2 )[ind]))))
164+ S′ = S[ind, ind]
165+ V′ = Eye ((only (axes (axes (V, 1 )[ind])), axes (V, 2 )))
166+ return (U′, S′, V′)
167+ end
168+
130169function MatrixAlgebraKit. svd_vals! (A:: AbstractZerosMatrix , F, alg:: ZerosAlgorithm )
131170 return diagview (A)
132171end
@@ -193,8 +232,9 @@ for f in [:eig_trunc!, :eigh_trunc!]
193232 function MatrixAlgebraKit. truncate! (:: typeof ($ f), (D, V):: Tuple{Eye,Eye} ,
194233 strategy:: TruncationStrategy )
195234 ind = findtruncated (diagview (D), strategy)
196- return Diagonal (diagview (D)[ind]),
197- Eye ((axes (V, 1 ), only (axes (axes (V, 2 )[ind]))))
235+ D′ = Diagonal (diagview (D)[ind])
236+ U′ = Eye ((axes (V, 1 ), only (axes (axes (V, 2 )[ind]))))
237+ return (D′, U′)
198238 end
199239 end
200240end
0 commit comments