@@ -175,21 +175,21 @@ const GPU_SVDPolar = Union{CUSOLVER_SVDPolar}
175175const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi}
176176const GPU_Randomized = Union{CUSOLVER_Randomized}
177177
178- function check_input (:: typeof (svd_trunc!), A:: AbstractMatrix , USVᴴ, alg:: TruncatedAlgorithm{ CUSOLVER_Randomized} )
178+ function check_input (:: typeof (svd_trunc!), A:: AbstractMatrix , USVᴴ, alg:: CUSOLVER_Randomized )
179179 m, n = size (A)
180180 minmn = min (m, n)
181181 U, S, Vᴴ = USVᴴ
182182 @assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix
183183 @check_size (U, (m, m))
184184 @check_scalar (U, A)
185- @check_size (S, (minmn,minmn))
185+ @check_size (S, (minmn, minmn))
186186 @check_scalar (S, A, real)
187187 @check_size (Vᴴ, (n, n))
188188 @check_scalar (Vᴴ, A)
189189 return nothing
190190end
191191
192- function initialize_output (:: typeof (svd_trunc!), A:: AbstractMatrix , alg:: TruncatedAlgorithm{CUSOLVER_Randomized} )
192+ function initialize_output (:: typeof (svd_trunc!), A:: AbstractMatrix , alg:: TruncatedAlgorithm{<: CUSOLVER_Randomized} )
193193 m, n = size (A)
194194 minmn = min (m, n)
195195 U = similar (A, (m, m))
@@ -232,12 +232,12 @@ function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgor
232232end
233233
234234function svd_trunc! (A:: AbstractMatrix , USVᴴ, alg:: TruncatedAlgorithm{<:GPU_Randomized} )
235- check_input (svd_trunc!, A, USVᴴ, alg)
235+ check_input (svd_trunc!, A, USVᴴ, alg. alg )
236236 U, S, Vᴴ = USVᴴ
237- _gpu_Xgesvdr! (A, S. diag, U, Vᴴ; alg. kwargs... )
237+ _gpu_Xgesvdr! (A, S. diag, U, Vᴴ; alg. alg . kwargs... )
238238 # TODO : make this controllable using a `gaugefix` keyword argument
239- gaugefix! (Val (:compact ), U, S, Vᴴ, m, n )
240- return truncate! (svd_trunc!, USVᴴ′ , alg. trunc)
239+ gaugefix! (Val (:trunc ), U, S, Vᴴ, size (A) ... )
240+ return truncate! (svd_trunc!, USVᴴ, alg. trunc)
241241end
242242
243243function MatrixAlgebraKit. svd_compact! (A:: AbstractMatrix , USVᴴ, alg:: GPU_SVDAlgorithm )
@@ -255,7 +255,7 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl
255255 throw (ArgumentError (" Unsupported SVD algorithm" ))
256256 end
257257 # TODO : make this controllable using a `gaugefix` keyword argument
258- gaugefix! (Val (:compact ), U, S, Vᴴ, m, n)
258+ gaugefix! (Val (:compact ), U, S, Vᴴ, size (A) ... )
259259 return USVᴴ
260260end
261261_argmaxabs (x) = reduce (_largest, x; init= zero (eltype (x)))
0 commit comments