Skip to content

Commit 3f7d45d

Browse files
committed
another attempt
1 parent 2388634 commit 3f7d45d

2 files changed

Lines changed: 81 additions & 10 deletions

File tree

src/implementations/svd.jl

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
134134
elseif alg isa LAPACK_SafeDivideAndConquer
135135
isempty(alg_kwargs) ||
136136
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
137-
YALAPACK.gesdvd!(A, copy_input(svd_full, A), view(S, 1:minmn, 1), U, Vᴴ)
137+
YALAPACK.gesdvd!(A, copy(A), view(S, 1:minmn, 1), U, Vᴴ)
138138
elseif alg isa LAPACK_Bisection
139139
throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
140140
elseif alg isa LAPACK_Jacobi
@@ -179,7 +179,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
179179
elseif alg isa LAPACK_SafeDivideAndConquer
180180
isempty(alg_kwargs) ||
181181
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
182-
YALAPACK.gesdvd!(A, copy_input(svd_compact, A), diagview(S), U, Vᴴ)
182+
YALAPACK.gesdvd!(A, copy(A), diagview(S), U, Vᴴ)
183183
elseif alg isa LAPACK_Bisection
184184
YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...)
185185
elseif alg isa LAPACK_Jacobi
@@ -218,7 +218,7 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
218218
elseif alg isa LAPACK_SafeDivideAndConquer
219219
isempty(alg_kwargs) ||
220220
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
221-
YALAPACK.gesdvd!(A, copy_input(svd_vals, A), S, U, Vᴴ)
221+
YALAPACK.gesdvd!(A, copy(A), S, U, Vᴴ)
222222
elseif alg isa LAPACK_Bisection
223223
YALAPACK.gesvdx!(A, S, U, Vᴴ; alg_kwargs...)
224224
elseif alg isa LAPACK_Jacobi
@@ -233,21 +233,92 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
233233
end
234234

235235
# avoid double allocation
236-
for f in (:svd_compact, :svd_full, :svd_vals)
237-
f! = Symbol(f, :!)
238-
@eval $f(A, alg::LAPACK_SafeDivideAndConquer) = $f!(A, alg)
236+
# Non-mutating full SVD for `LAPACK_SafeDivideAndConquer`.
#
# Only one working copy of `A` is created (via `copy_input`); the output `(U, S, Vᴴ)` is
# allocated and validated against that copy, and `YALAPACK.gesdvd!` receives both the
# caller's `A` and the copy.
# NOTE(review): presumably `gesdvd!` keeps its first argument intact and uses the second
# one as scratch space — confirm against the definition in src/yalapack.jl.
#
# Returns `(U, S, Vᴴ)`. For non-empty input, throws an `ArgumentError` when `alg.kwargs`
# contains any key other than `fixgauge`.
function svd_full(A::AbstractMatrix, alg::LAPACK_SafeDivideAndConquer)
    workcopy = copy_input(svd_full, A)
    out = initialize_output(svd_full!, workcopy, alg)
    check_input(svd_full!, workcopy, out, alg)

    U, S, Vᴴ = out
    zero!(S)

    # Empty input: skip the LAPACK call entirely and return `one!` of `U` and `Vᴴ`
    # together with the zeroed `S`.
    k = minimum(size(A))
    if k == 0
        return one!(U), S, one!(Vᴴ)
    end

    fix_gauge = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
    extra_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
    if !isempty(extra_kwargs)
        throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
    end

    # The singular values are first written into the leading `k` entries of column 1.
    YALAPACK.gesdvd!(A, workcopy, view(S, 1:k, 1), U, Vᴴ)

    # Move them onto the diagonal; entry (1, 1) is already in its final position.
    for j in 2:k
        S[j, j] = S[j, 1]
        S[j, 1] = zero(eltype(S))
    end

    if fix_gauge
        gaugefix!(svd_full!, U, Vᴴ)
    end

    return out
end
264+
# Non-mutating compact SVD for `LAPACK_SafeDivideAndConquer`.
#
# Only one working copy of `A` is created (via `copy_input`); the output `(U, S, Vᴴ)` is
# allocated and validated against that copy, and `YALAPACK.gesdvd!` receives both the
# caller's `A` and the copy.
# NOTE(review): presumably `gesdvd!` keeps its first argument intact and uses the second
# one as scratch space — confirm against the definition in src/yalapack.jl.
#
# Returns `(U, S, Vᴴ)`. For non-empty input, throws an `ArgumentError` when `alg.kwargs`
# contains any key other than `fixgauge`.
function svd_compact(A::AbstractMatrix, alg::LAPACK_SafeDivideAndConquer)
    workcopy = copy_input(svd_compact, A)
    out = initialize_output(svd_compact!, workcopy, alg)
    check_input(svd_compact!, workcopy, out, alg)

    U, S, Vᴴ = out
    zero!(S)

    # Empty input: skip the LAPACK call entirely and return `one!` of `U` and `Vᴴ`
    # together with the zeroed `S`.
    k = minimum(size(A))
    if k == 0
        return one!(U), S, one!(Vᴴ)
    end

    fix_gauge = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
    extra_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
    if !isempty(extra_kwargs)
        throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
    end

    # The singular values are written straight into the diagonal view of `S`.
    YALAPACK.gesdvd!(A, workcopy, diagview(S), U, Vᴴ)

    if fix_gauge
        gaugefix!(svd_compact!, U, Vᴴ)
    end

    return out
end
240-
for f in (:svd_trunc, :svd_trunc_no_error)
241-
f! = Symbol(f, :!)
242-
@eval $f(A, alg::TruncatedAlgorithm{<:LAPACK_SafeDivideAndConquer}) = $f!(A, alg)
286+
# Non-mutating singular-value computation for `LAPACK_SafeDivideAndConquer`.
#
# The method is restricted to `LAPACK_SafeDivideAndConquer`, matching the sibling
# `svd_full` / `svd_compact` methods: the body unconditionally calls `YALAPACK.gesdvd!`
# and reports errors in the name of that algorithm, so it must not capture the other
# `LAPACK_SVDAlgorithm` subtypes (those keep using the generic `svd_vals` method).
#
# Only one working copy of `A` is created (via `copy_input`); the output `S` is allocated
# and validated against that copy, and `gesdvd!` receives both the caller's `A` and the
# copy.
# NOTE(review): presumably `gesdvd!` keeps its first argument intact and uses the second
# one as scratch space — confirm against the definition in src/yalapack.jl.
#
# Returns `S`. For non-empty input, throws an `ArgumentError` when `alg.kwargs` contains
# any key other than `fixgauge` (the `fixgauge` value itself is not used here).
function svd_vals(A::AbstractMatrix, alg::LAPACK_SafeDivideAndConquer)
    Ac = copy_input(svd_vals, A)
    S = initialize_output(svd_vals!, Ac, alg)
    check_input(svd_vals!, Ac, S, alg)

    # Empty input: nothing to compute, return the zeroed output.
    minmn = min(size(A)...)
    minmn == 0 && return zero!(S)

    # Zero-sized placeholders are passed in the `U` and `Vᴴ` positions.
    U, Vᴴ = similar(Ac, (0, 0)), similar(Ac, (0, 0))

    alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
    isempty(alg_kwargs) ||
        throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))

    YALAPACK.gesdvd!(A, Ac, S, U, Vᴴ)

    return S
end
244304

305+
# Non-mutating truncated SVD without the truncation error: run `svd_compact` on `A` with
# the wrapped algorithm `alg.alg`, truncate the result according to `alg.trunc`, and
# return only the truncated factors (the second value returned by `truncate` is dropped).
function svd_trunc_no_error(A, alg::TruncatedAlgorithm)
    factors = svd_compact(A, alg.alg)
    truncated, _ = truncate(svd_trunc!, factors, alg.trunc)
    return truncated
end
245310
# Truncated SVD without the truncation error, using caller-provided output `USVᴴ`: run
# `svd_compact!` with the wrapped algorithm `alg.alg`, truncate the three factors
# according to `alg.trunc`, and return only the truncated factors (the second value
# returned by `truncate` is dropped).
function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm)
    U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
    truncated, _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
    return truncated
end
250315

316+
# Non-mutating truncated SVD: run `svd_compact` on `A` with the wrapped algorithm
# `alg.alg`, truncate the result according to `alg.trunc`, and return the truncated
# factors followed by the value `ϵ` obtained from `truncation_error!`.
function svd_trunc(A, alg::TruncatedAlgorithm)
    factors = svd_compact(A, alg.alg)
    truncated, idx = truncate(svd_trunc!, factors, alg.trunc)
    # `truncation_error!` is given the diagonal of the untruncated `S` plus the index
    # information returned by `truncate`.
    # NOTE(review): the `!` suggests it may overwrite that diagonal — confirm.
    ϵ = truncation_error!(diagview(factors[2]), idx)
    return (truncated..., ϵ)
end
251322
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
252323
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
253324
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)

src/yalapack.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2420,7 +2420,7 @@ function gesdvd!(
24202420
else
24212421
lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1)
24222422
end
2423-
rwork = Vector{T}(undef, lrwork)
2423+
rwork = Vector{Tr}(undef, lrwork)
24242424
else
24252425
rwork = nothing
24262426
end

0 commit comments

Comments
 (0)