Skip to content

Commit e4f7c8e

Browse files
committed
Small changes to unblock TensorKit svd
1 parent 1f5fea2 commit e4f7c8e

3 files changed

Lines changed: 8 additions & 6 deletions

File tree

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT<:
2222
return LQViaTransposedQR(qr_alg)
2323
end
2424
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
25-
return CUSOLVER_QRIteration(; kwargs...)
25+
return CUSOLVER_Jacobi(; kwargs...)
2626
end
2727
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
2828
return CUSOLVER_Simple(; kwargs...)

src/implementations/polar.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ function left_polar!(A::AbstractMatrix, WP, alg::PolarViaSVD)
5050
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
5151
W, P = WP
5252
W = mul!(W, U, Vᴴ)
53-
S .= sqrt.(S)
54-
SsqrtVᴴ = lmul!(S, Vᴴ)
53+
@. S = sqrt(S)
54+
@. Vᴴ *= S
55+
SsqrtVᴴ = Vᴴ
5556
P = mul!(P, SsqrtVᴴ', SsqrtVᴴ)
5657
return (W, P)
5758
end
@@ -60,8 +61,9 @@ function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarViaSVD)
6061
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
6162
P, Wᴴ = PWᴴ
6263
Wᴴ = mul!(Wᴴ, U, Vᴴ)
63-
S .= sqrt.(S)
64-
USsqrt = rmul!(U, S)
64+
@. S = sqrt(S)
65+
@. U *= S
66+
USsqrt = U
6567
P = mul!(P, USsqrt, USsqrt')
6668
return (P, Wᴴ)
6769
end

src/implementations/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl
288288
U, S, Vᴴ = USVᴴ
289289
if alg isa GPU_QRIteration
290290
isempty(alg.kwargs) ||
291-
throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments"))
291+
@warn "GPU_QRIteration does not accept any keyword arguments"
292292
_gpu_gesvd!(A, S.diag, U, Vᴴ)
293293
elseif alg isa GPU_SVDPolar
294294
_gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...)

0 commit comments

Comments
 (0)