Skip to content

Commit 0ea7bfd

Browse files
committed
Fixes
1 parent 21b2382 commit 0ea7bfd

2 files changed

Lines changed: 17 additions & 1 deletion

File tree

src/implementations/svd.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,14 @@ end
152152
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
153153
check_input(svd_compact!, A, USVᴴ, alg)
154154
U, S, Vᴴ = USVᴴ
155+
m, n = size(A)
156+
minmn = min(m, n)
157+
if minmn == 0
158+
one!(U)
159+
zero!(S)
160+
one!(Vᴴ)
161+
return USVᴴ
162+
end
155163

156164
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
157165
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
@@ -407,6 +415,14 @@ end
407415
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
408416
check_input(svd_compact!, A, USVᴴ, alg)
409417
U, S, Vᴴ = USVᴴ
418+
m, n = size(A)
419+
minmn = min(m, n)
420+
if minmn == 0
421+
one!(U)
422+
zero!(S)
423+
one!(Vᴴ)
424+
return USVᴴ
425+
end
410426

411427
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
412428
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})

test/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ for T in (BLASFloats..., GenericFloats...), m in (0, 54), n in (0, 37, m, 63)
2525
)
2626
TestSuite.test_svd_algs(CuMatrix{T}, (m, n), CUDA_SVD_ALGS; test_trunc = false)
2727
if n == m
28-
TestSuite.test_svd(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),))
28+
TestSuite.test_svd(Diagonal{T, CuVector{T}}, m)
2929
TestSuite.test_svd_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),))
3030
end
3131
end

0 commit comments

Comments
 (0)