@@ -105,133 +105,106 @@ function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithm
105105 return eltype (A) <: Real ? diagview (A) : similar (A, real (eltype (A)), size (A, 1 ))
106106end
107107
108- # Implementation
109- # --------------
110- function svd_full! (A:: AbstractMatrix , USVᴴ, alg:: LAPACK_SVDAlgorithm )
111- check_input (svd_full!, A, USVᴴ, alg)
112- U, S, Vᴴ = USVᴴ
113- fill! (S, zero (eltype (S)))
114- m, n = size (A)
115- minmn = min (m, n)
116- if minmn == 0
117- one! (U)
118- zero! (S)
119- one! (Vᴴ)
120- return USVᴴ
121- end
122-
123- do_gauge_fix = get (alg. kwargs, :fixgauge , default_fixgauge ()):: Bool
124- alg_kwargs = Base. structdiff (alg. kwargs, NamedTuple{(:fixgauge ,)})
125-
126- if alg isa LAPACK_QRIteration
127- isempty (alg_kwargs) ||
128- throw (ArgumentError (" invalid keyword arguments for LAPACK_QRIteration" ))
129- YALAPACK. gesvd! (A, view (S, 1 : minmn, 1 ), U, Vᴴ)
130- elseif alg isa LAPACK_DivideAndConquer
131- isempty (alg_kwargs) ||
132- throw (ArgumentError (" invalid keyword arguments for LAPACK_DivideAndConquer" ))
133- YALAPACK. gesdd! (A, view (S, 1 : minmn, 1 ), U, Vᴴ)
134- elseif alg isa LAPACK_SafeDivideAndConquer
135- isempty (alg_kwargs) ||
136- throw (ArgumentError (" invalid keyword arguments for LAPACK_SafeDivideAndConquer" ))
137- YALAPACK. gesdvd! (A, view (S, 1 : minmn, 1 ), U, Vᴴ)
138- elseif alg isa LAPACK_Bisection
139- throw (ArgumentError (" LAPACK_Bisection is not supported for full SVD" ))
140- elseif alg isa LAPACK_Jacobi
141- throw (ArgumentError (" LAPACK_Jacobi is not supported for full SVD" ))
142- else
143- throw (ArgumentError (" Unsupported SVD algorithm" ))
144- end
145-
146- for i in 2 : minmn
147- S[i, i] = S[i, 1 ]
148- S[i, 1 ] = zero (eltype (S))
149- end
150-
151- do_gauge_fix && gaugefix! (svd_full!, U, Vᴴ)
108+ # ==========================
109+ # IMPLEMENTATIONS
110+ # ==========================
152111
153- return USVᴴ
112+ for f! in (:gesdd! , :gesvd! , :gesvdj! , :gesvdp! , :gesvdx! , :gesvdr! , :gesdvd! )
113+ @eval $ f! (driver:: Driver , args... ) = throw (ArgumentError (" $driver does not provide $f! " ))
154114end
155115
156- function svd_compact! (A:: AbstractMatrix , USVᴴ, alg:: LAPACK_SVDAlgorithm )
157- check_input (svd_compact!, A, USVᴴ, alg)
158- U, S, Vᴴ = USVᴴ
159- m, n = size (A)
160- minmn = min (m, n)
161- if minmn == 0
162- one! (U)
163- zero! (S)
164- one! (Vᴴ)
165- return USVᴴ
166- end
167-
168- do_gauge_fix = get (alg. kwargs, :fixgauge , default_fixgauge ()):: Bool
169- alg_kwargs = Base. structdiff (alg. kwargs, NamedTuple{(:fixgauge ,)})
170-
171- if alg isa LAPACK_QRIteration
172- isempty (alg_kwargs) ||
173- throw (ArgumentError (" invalid keyword arguments for LAPACK_QRIteration" ))
174- YALAPACK. gesvd! (A, diagview (S), U, Vᴴ)
175- elseif alg isa LAPACK_DivideAndConquer
176- isempty (alg_kwargs) ||
177- throw (ArgumentError (" invalid keyword arguments for LAPACK_DivideAndConquer" ))
178- YALAPACK. gesdd! (A, diagview (S), U, Vᴴ)
179- elseif alg isa LAPACK_SafeDivideAndConquer
180- isempty (alg_kwargs) ||
181- throw (ArgumentError (" invalid keyword arguments for LAPACK_SafeDivideAndConquer" ))
182- YALAPACK. gesdvd! (A, diagview (S), U, Vᴴ)
183- elseif alg isa LAPACK_Bisection
184- YALAPACK. gesvdx! (A, diagview (S), U, Vᴴ; alg_kwargs... )
185- elseif alg isa LAPACK_Jacobi
186- isempty (alg_kwargs) ||
187- throw (ArgumentError (" invalid keyword arguments for LAPACK_Jacobi" ))
188- YALAPACK. gesvj! (A, diagview (S), U, Vᴴ)
189- else
190- throw (ArgumentError (" Unsupported SVD algorithm" ))
191- end
192-
193- do_gauge_fix && gaugefix! (svd_compact!, U, Vᴴ)
194-
195- return USVᴴ
116+ # LAPACK
117+ for f! in (:gesdd! , :gesvd! , :gesvdj! , :gesvdx! , :gesdvd! )
118+ @eval $ f! (:: LAPACK , args... ; kwargs... ) = YALAPACK.$ f! (args... ; kwargs... )
196119end
197120
198- function svd_vals! (A:: AbstractMatrix , S, alg:: LAPACK_SVDAlgorithm )
199- check_input (svd_vals!, A, S, alg)
200- m, n = size (A)
201- minmn = min (m, n)
202- if minmn == 0
203- zero! (S)
204- return S
121+ for (f, f_lapack!, Alg) in (
122+ (:safe_divide_and_conquer , :gesdvd! , :SafeDivideAndConquer ),
123+ (:divide_and_conquer , :gesdd! , :DivideAndConquer ),
124+ (:qr_iteration , :gesvd! , :QRIteration ),
125+ (:bisection , :gesvdx! , :Bisection ),
126+ (:jacobi , :gesvdj! , :Jacobi ),
127+ )
128+ f_svd! = Symbol (f, :_svd! )
129+ f_svd_full! = Symbol (f, :_svd_full! )
130+ f_svd_vals! = Symbol (f, :_svd_vals! )
131+
132+ # MatrixAlgebraKit wrappers
133+ @eval begin
134+ function svd_compact! (A, USVᴴ, alg:: $Alg )
135+ check_input (svd_compact!, A, USVᴴ, alg)
136+ return $ f_svd! (A, USVᴴ... ; alg. kwargs... )
137+ end
138+ function svd_full! (A, USVᴴ, alg:: $Alg )
139+ check_input (svd_full!, A, USVᴴ, alg)
140+ return $ f_svd_full! (A, USVᴴ... ; alg. kwargs... )
141+ end
142+ function svd_vals! (A, S, alg:: $Alg )
143+ check_input (svd_vals!, A, S, alg)
144+ return $ f_svd_vals! (A, S; alg. kwargs... )
145+ end
205146 end
206- U, Vᴴ = similar (A, (0 , 0 )), similar (A, (0 , 0 ))
207-
208- alg_kwargs = Base. structdiff (alg. kwargs, NamedTuple{(:fixgauge ,)})
209147
210- if alg isa LAPACK_QRIteration
211- isempty (alg_kwargs) ||
212- throw (ArgumentError (" invalid keyword arguments for LAPACK_QRIteration" ))
213- YALAPACK. gesvd! (A, S, U, Vᴴ)
214- elseif alg isa LAPACK_DivideAndConquer
215- isempty (alg_kwargs) ||
216- throw (ArgumentError (" invalid keyword arguments for LAPACK_DivideAndConquer" ))
217- YALAPACK. gesdd! (A, S, U, Vᴴ)
218- elseif alg isa LAPACK_SafeDivideAndConquer
219- isempty (alg_kwargs) ||
220- throw (ArgumentError (" invalid keyword arguments for LAPACK_SafeDivideAndConquer" ))
221- YALAPACK. gesdvd! (A, S, U, Vᴴ)
222- elseif alg isa LAPACK_Bisection
223- YALAPACK. gesvdx! (A, S, U, Vᴴ; alg_kwargs... )
224- elseif alg isa LAPACK_Jacobi
225- isempty (alg_kwargs) ||
226- throw (ArgumentError (" invalid keyword arguments for LAPACK_Jacobi" ))
227- YALAPACK. gesvj! (A, S, U, Vᴴ)
228- else
229- throw (ArgumentError (" Unsupported SVD algorithm" ))
148+ # driver
149+ @eval begin
150+ @inline $ f_svd! (A, U, S, Vᴴ; driver:: Driver = DefaultDriver (), kwargs... ) =
151+ $ f_svd! (driver, A, U, S, Vᴴ; kwargs... )
152+ @inline $ f_svd_full! (A, U, S, Vᴴ; driver:: Driver = DefaultDriver (), kwargs... ) =
153+ $ f_svd_full! (driver, A, U, S, Vᴴ; kwargs... )
154+ @inline $ f_svd_vals! (A, S; driver:: Driver = DefaultDriver (), kwargs... ) =
155+ $ f_svd_vals! (driver, A, S; kwargs... )
156+ @inline $ f_svd! (:: DefaultDriver , A, U, S, Vᴴ; kwargs... ) =
157+ $ f_svd! ($ (Symbol (:default_ , f, :_driver ))(A), A, U, S, Vᴴ; kwargs... )
158+ @inline $ f_svd_full! (:: DefaultDriver , A, S; kwargs... ) =
159+ $ f_svd_full! ($ (Symbol (:default_ , f, :_driver )), A, S; kwargs... )
160+ @inline $ f_svd_vals! (:: DefaultDriver , A, S; kwargs... ) =
161+ $ f_svd_vals! ($ (Symbol (:default_ , f, :_driver )), A, S; kwargs... )
230162 end
231163
232- return S
164+ # Implementation
165+ @eval begin
166+ function $f_svd! (
167+ driver:: Driver , A:: AbstractMatrix , U:: AbstractMatrix , S:: AbstractMatrix , Vᴴ:: AbstractMatrix ;
168+ fixgauge:: Bool = true , kwargs...
169+ )
170+ supports_svd (driver, $ (QuoteNode (f))) || throw (ArgumentError (lazy " $driver does not provide $f" ))
171+ isempty (A) && return one! (U), zero! (S), one! (Vᴴ)
172+ $ f_lapack! (driver, A, view (S, 1 : minmn, 1 ), U, Vᴴ; kwargs... )
173+ fixgauge && gaugefix! (svd_compact!, U, Vᴴ)
174+ return U, S, Vᴴ
175+ end
176+ function $f_svd_full! (
177+ driver:: Driver , A:: AbstractMatrix , U:: AbstractMatrix , S:: AbstractMatrix , Vᴴ:: AbstractMatrix ;
178+ fixgauge:: Bool = true , kwargs...
179+ )
180+ supports_svd_full (driver, $ (QuoteNode (f))) || throw (ArgumentError (lazy " $driver does not provide $f" ))
181+ isempty (A) && return one! (U), zero! (S), one! (Vᴴ)
182+ zero! (S)
183+ minmn = min (size (A)... )
184+ $ f_lapack! (driver, A, view (S, 1 : minmn, 1 ), U, Vᴴ; kwargs... )
185+ diagview (S) .= view (S, 1 : minmn, 1 )
186+ view (S, 2 : minmn, 1 ) .= zero (eltype (S))
187+ fixgauge && gaugefix! (svd_full!, U, Vᴴ)
188+ return U, S, Vᴴ
189+ end
190+ function $f_svd_vals! (
191+ driver:: Driver , A:: AbstractMatrix , S:: AbstractVector ;
192+ fixgauge:: Bool = true , kwargs...
193+ )
194+ supports_svd (driver, $ (QuoteNode (f))) || throw (ArgumentError (lazy " $driver does not provide $f" ))
195+ isempty (A) && return zero! (S)
196+ U, Vᴴ = similar (A, (0 , 0 )), similar (A, (0 , 0 ))
197+ $ f_lapack! (driver, A, view (S, 1 : minmn, 1 ), U, Vᴴ; kwargs... )
198+ return S
199+ end
200+ end
233201end
234202
203+ supports_svd (:: Driver , :: Symbol ) = false
204+ supports_svd (:: LAPACK , f:: Symbol ) = f in (:safe_divide_and_conquer , :divide_and_conquer , :qr_iteration , :bisection , :jacobi )
205+ supports_svd_full (:: Driver , :: Symbol ) = false
206+ supports_svd_full (:: LAPACK , f:: Symbol ) = f in (:safe_divide_and_conquer , :divide_and_conquer , :qr_iteration )
207+
235208function svd_trunc_no_error! (A, USVᴴ, alg:: TruncatedAlgorithm )
236209 U, S, Vᴴ = svd_compact! (A, USVᴴ, alg. alg)
237210 USVᴴtrunc, ind = truncate (svd_trunc!, (U, S, Vᴴ), alg. trunc)
@@ -485,3 +458,23 @@ function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
485458
486459 return S
487460end
461+
462+ # Deprecations
463+ # ------------
464+ for algtype in (:DivideAndConquer , :QRIteration , :Jacobi , :Bisection )
465+ algtype = Symbol (:LAPACK_ , algtype)
466+ @eval begin
467+ Base. @deprecate (
468+ svd_compact! (A, USVᴴ, alg:: $algtype ),
469+ svd_compact! (A, USVᴴ, $ algtype (; driver = LAPACK (), alg. kwargs... ))
470+ )
471+ Base. @deprecate (
472+ svd_full! (A, USVᴴ, alg:: $algtype ),
473+ svd_full! (A, USVᴴ, $ algtype (; driver = LAPACK (), alg. kwargs... ))
474+ )
475+ Base. @deprecate (
476+ svd_vals! (A, S, alg:: $algtype ),
477+ svd_vals! (A, S, $ algtype (; driver = LAPACK (), alg. kwargs... ))
478+ )
479+ end
480+ end
0 commit comments