Skip to content

Commit 62ca99d

Browse files
committed
Add qr_ implementations
1 parent 50b5246 commit 62ca99d

3 files changed

Lines changed: 676 additions & 0 deletions

File tree

src/tensors/matrixalgebrakit.jl

Lines changed: 393 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,396 @@ function MatrixAlgebraKit.default_eigh_algorithm(t::AbstractTensorMap{<:BlasFloa
223223
return BlockAlgorithm(LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...),
224224
scheduler)
225225
end
226+
227+
# QR decomposition
228+
# ----------------
229+
function MatrixAlgebraKit.check_input(::typeof(qr_full!), t::AbstractTensorMap,
230+
(Q,
231+
R)::Tuple{<:AbstractTensorMap,<:AbstractTensorMap})
232+
# scalartype checks
233+
@check_eltype Q t
234+
@check_eltype R t
235+
236+
# space checks
237+
V_Q = fuse(codomain(t))
238+
space(Q) == (codomain(t) V_Q) ||
239+
throw(SpaceMismatch("`qr_full!(t, (Q, R))` requires `space(Q) == (codomain(t) ← fuse(codomain(t)))`"))
240+
space(R) == (V_Q domain(t)) ||
241+
throw(SpaceMismatch("`qr_full!(t, (Q, R))` requires `space(R) == (fuse(codomain(t)) ← domain(t)`"))
242+
243+
return nothing
244+
end
245+
246+
function MatrixAlgebraKit.check_input(::typeof(qr_compact!), t::AbstractTensorMap, (Q, R))
247+
# scalartype checks
248+
@check_eltype Q t
249+
@check_eltype R t
250+
251+
# space checks
252+
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
253+
space(Q) == (codomain(t) V_Q) ||
254+
throw(SpaceMismatch("`qr_compact!(t, (Q, R))` requires `space(Q) == (codomain(t) ← infimum(fuse(codomain(t)), fuse(domain(t)))`"))
255+
space(R) == (V_Q domain(t)) ||
256+
throw(SpaceMismatch("`qr_compact!(t, (Q, R))` requires `space(R) == (infimum(fuse(codomain(t)), fuse(domain(t))) ← domain(t))`"))
257+
258+
return nothing
259+
end
260+
261+
function MatrixAlgebraKit.check_input(::typeof(qr_null!), t::AbstractTensorMap,
262+
N::AbstractTensorMap)
263+
# scalartype checks
264+
@check_eltype N t
265+
266+
# space checks
267+
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
268+
V_N = setdiff(fuse(codomain(t)), V_Q)
269+
space(N) == (codomain(t) V_N) ||
270+
throw(SpaceMismatch("`qr_null!(t, N)` requires `space(N) == (codomain(t) ← setdiff(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`"))
271+
272+
return nothing
273+
end
274+
275+
function MatrixAlgebraKit.initialize_output(::typeof(qr_full!), t::AbstractTensorMap,
276+
::MatrixAlgebraKit.AbstractAlgorithm)
277+
V_Q = fuse(codomain(t))
278+
Q = similar(t, codomain(t) V_Q)
279+
R = similar(t, V_Q domain(t))
280+
return Q, R
281+
end
282+
283+
function MatrixAlgebraKit.initialize_output(::typeof(qr_compact!), t::AbstractTensorMap,
284+
::MatrixAlgebraKit.AbstractAlgorithm)
285+
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
286+
Q = similar(t, codomain(t) V_Q)
287+
R = similar(t, V_Q domain(t))
288+
return Q, R
289+
end
290+
291+
function MatrixAlgebraKit.initialize_output(::typeof(qr_null!), t::AbstractTensorMap,
292+
::MatrixAlgebraKit.AbstractAlgorithm)
293+
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
294+
V_N = setdiff(fuse(codomain(t)), V_Q)
295+
N = similar(t, codomain(t) V_N)
296+
return N
297+
end
298+
299+
function MatrixAlgebraKit.qr_full!(t::AbstractTensorMap, (Q, R),
300+
alg::BlockAlgorithm)
301+
MatrixAlgebraKit.check_input(qr_full!, t, (Q, R))
302+
303+
foreachblock(t, Q, R; alg.scheduler) do _, (b, q, r)
304+
q′, r′ = qr_full!(b, (q, r), alg.alg)
305+
# deal with the case where the output is not the same as the input
306+
q === q′ || copyto!(q, q′)
307+
r === r′ || copyto!(r, r′)
308+
return nothing
309+
end
310+
311+
return Q, R
312+
end
313+
314+
function MatrixAlgebraKit.qr_compact!(t::AbstractTensorMap, (Q, R),
315+
alg::BlockAlgorithm)
316+
MatrixAlgebraKit.check_input(qr_compact!, t, (Q, R))
317+
318+
foreachblock(t, Q, R; alg.scheduler) do _, (b, q, r)
319+
q′, r′ = qr_compact!(b, (q, r), alg.alg)
320+
# deal with the case where the output is not the same as the input
321+
q === q′ || copyto!(q, q′)
322+
r === r′ || copyto!(r, r′)
323+
return nothing
324+
end
325+
326+
return Q, R
327+
end
328+
329+
function MatrixAlgebraKit.qr_null!(t::AbstractTensorMap, N, alg::BlockAlgorithm)
330+
MatrixAlgebraKit.check_input(qr_null!, t, N)
331+
332+
foreachblock(t, N; alg.scheduler) do _, (b, n)
333+
n′ = qr_null!(b, n, alg.alg)
334+
# deal with the case where the output is not the same as the input
335+
n === n′ || copyto!(n, n′)
336+
return nothing
337+
end
338+
339+
return N
340+
end
341+
342+
function MatrixAlgebraKit.default_qr_algorithm(t::AbstractTensorMap{<:BlasFloat};
343+
scheduler=default_blockscheduler(t),
344+
kwargs...)
345+
return BlockAlgorithm(LAPACK_HouseholderQR(; kwargs...), scheduler)
346+
end
347+
348+
# LQ decomposition
349+
# ----------------
350+
function MatrixAlgebraKit.check_input(::typeof(lq_full!), t::AbstractTensorMap, (L, Q))
351+
# scalartype checks
352+
@check_eltype L t
353+
@check_eltype Q t
354+
355+
# space checks
356+
V_Q = fuse(domain(t))
357+
space(L) == (codomain(t) V_Q) ||
358+
throw(SpaceMismatch("`lq_full!(t, (L, Q))` requires `space(L) == (codomain(t) ← fuse(domain(t)))`"))
359+
space(Q) == (V_Q domain(t)) ||
360+
throw(SpaceMismatch("`lq_full!(t, (L, Q))` requires `space(Q) == (fuse(domain(t)) ← domain(t))`"))
361+
362+
return nothing
363+
end
364+
365+
function MatrixAlgebraKit.check_input(::typeof(lq_compact!), t::AbstractTensorMap, (L, Q))
366+
# scalartype checks
367+
@check_eltype L t
368+
@check_eltype Q t
369+
370+
# space checks
371+
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
372+
space(L) == (codomain(t) V_Q) ||
373+
throw(SpaceMismatch("`lq_compact!(t, (L, Q))` requires `space(L) == infimum(fuse(codomain(t)), fuse(domain(t)))`"))
374+
space(Q) == (V_Q domain(t)) ||
375+
throw(SpaceMismatch("`lq_compact!(t, (L, Q))` requires `space(Q) == infimum(fuse(codomain(t)), fuse(domain(t)))`"))
376+
377+
return nothing
378+
end
379+
380+
function MatrixAlgebraKit.check_input(::typeof(lq_null!), t::AbstractTensorMap, N)
381+
# scalartype checks
382+
@check_eltype N t
383+
384+
# space checks
385+
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
386+
V_N = setdiff(fuse(domain(t)), V_Q)
387+
space(N) == (V_N domain(t)) ||
388+
throw(SpaceMismatch("`lq_null!(t, N)` requires `space(N) == setdiff(fuse(domain(t)), infimum(fuse(codomain(t)), fuse(domain(t)))`"))
389+
390+
return nothing
391+
end
392+
393+
function MatrixAlgebraKit.initialize_output(::typeof(lq_full!), t::AbstractTensorMap,
394+
::MatrixAlgebraKit.AbstractAlgorithm)
395+
V_Q = fuse(domain(t))
396+
L = similar(t, codomain(t) V_Q)
397+
Q = similar(t, V_Q domain(t))
398+
return L, Q
399+
end
400+
401+
function MatrixAlgebraKit.initialize_output(::typeof(lq_compact!), t::AbstractTensorMap,
402+
::MatrixAlgebraKit.AbstractAlgorithm)
403+
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
404+
L = similar(t, codomain(t) V_Q)
405+
Q = similar(t, V_Q domain(t))
406+
return L, Q
407+
end
408+
409+
function MatrixAlgebraKit.initialize_output(::typeof(lq_null!), t::AbstractTensorMap,
410+
::MatrixAlgebraKit.AbstractAlgorithm)
411+
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
412+
V_N = setdiff(fuse(domain(t)), V_Q)
413+
N = similar(t, V_N domain(t))
414+
return N
415+
end
416+
417+
function MatrixAlgebraKit.lq_full!(t::AbstractTensorMap, (L, Q),
418+
alg::BlockAlgorithm)
419+
MatrixAlgebraKit.check_input(lq_full!, t, (L, Q))
420+
421+
foreachblock(t, L, Q; alg.scheduler) do _, (b, l, q)
422+
l′, q′ = lq_full!(b, (l, q), alg.alg)
423+
# deal with the case where the output is not the same as the input
424+
l === l′ || copyto!(l, l′)
425+
q === q′ || copyto!(q, q′)
426+
return nothing
427+
end
428+
429+
return L, Q
430+
end
431+
432+
function MatrixAlgebraKit.lq_compact!(t::AbstractTensorMap, (L, Q),
433+
alg::BlockAlgorithm)
434+
MatrixAlgebraKit.check_input(lq_compact!, t, (L, Q))
435+
436+
foreachblock(t, L, Q; alg.scheduler) do _, (b, l, q)
437+
l′, q′ = lq_compact!(b, (l, q), alg.alg)
438+
# deal with the case where the output is not the same as the input
439+
l === l′ || copyto!(l, l′)
440+
q === q′ || copyto!(q, q′)
441+
return nothing
442+
end
443+
444+
return L, Q
445+
end
446+
447+
function MatrixAlgebraKit.lq_null!(t::AbstractTensorMap, N, alg::BlockAlgorithm)
448+
MatrixAlgebraKit.check_input(lq_null!, t, N)
449+
450+
foreachblock(t, N; alg.scheduler) do _, (b, n)
451+
n′ = lq_null!(b, n, alg.alg)
452+
# deal with the case where the output is not the same as the input
453+
n === n′ || copyto!(n, n′)
454+
return nothing
455+
end
456+
457+
return N
458+
end
459+
460+
# Polar decomposition
461+
# -------------------
462+
using MatrixAlgebraKit: PolarViaSVD
463+
464+
function MatrixAlgebraKit.check_input(::typeof(left_polar!), t, (W, P))
465+
codomain(t) domain(t) ||
466+
throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`"))
467+
468+
# scalartype checks
469+
@check_eltype W t
470+
@check_eltype P t
471+
472+
# space checks
473+
space(W) == (codomain(t) fuse(domain(t))) ||
474+
throw(SpaceMismatch("`left_polar!(t, (W, P))` requires `space(W) == (codomain(t) ← domain(t))`"))
475+
space(P) == (fuse(domain(t)) domain(t)) ||
476+
throw(SpaceMismatch("`left_polar!(t, (W, P))` requires `space(P) == (domain(t) ← domain(t))`"))
477+
478+
return nothing
479+
end
480+
481+
# TODO: do we really not want to fuse the spaces?
482+
function MatrixAlgebraKit.initialize_output(::typeof(left_polar!), t::AbstractTensorMap)
483+
W = similar(t, codomain(t) fuse(domain(t)))
484+
P = similar(t, fuse(domain(t)) domain(t))
485+
return W, P
486+
end
487+
488+
function MatrixAlgebraKit.left_polar!(t::AbstractTensorMap, WP, alg::BlockAlgorithm)
489+
MatrixAlgebraKit.check_input(left_polar!, t, WP)
490+
491+
foreachblock(t, WP...; alg.scheduler) do _, (b, w, p)
492+
w′, p′ = left_polar!(b, (w, p), alg.alg)
493+
# deal with the case where the output is not the same as the input
494+
w === w′ || copyto!(w, w′)
495+
p === p′ || copyto!(p, p′)
496+
return nothing
497+
end
498+
499+
return WP
500+
end
501+
502+
function MatrixAlgebraKit.default_polar_algorithm(t::AbstractTensorMap{<:BlasFloat};
503+
scheduler=default_blockscheduler(t),
504+
kwargs...)
505+
return BlockAlgorithm(PolarViaSVD(LAPACK_DivideAndConquer(; kwargs...)),
506+
scheduler)
507+
end
508+
509+
# Orthogonalization
510+
# -----------------
511+
function MatrixAlgebraKit.check_input(::typeof(left_orth!), t::AbstractTensorMap, (V, C))
512+
# scalartype checks
513+
@check_eltype V t
514+
isnothing(C) || @check_eltype C t
515+
516+
# space checks
517+
V_C = infimum(fuse(codomain(t)), fuse(domain(t)))
518+
space(V) == (codomain(t) V_C) ||
519+
throw(SpaceMismatch("`left_orth!(t, (V, C))` requires `space(V) == (codomain(t) ← infimum(fuse(codomain(t)), fuse(domain(t))))`"))
520+
isnothing(C) || space(C) == (V_C domain(t)) ||
521+
throw(SpaceMismatch("`left_orth!(t, (V, C))` requires `space(C) == (infimum(fuse(codomain(t)), fuse(domain(t))) ← domain(t))`"))
522+
523+
return nothing
524+
end
525+
526+
function MatrixAlgebraKit.check_input(::typeof(right_orth!), t::AbstractTensorMap, (C, Vᴴ))
527+
# scalartype checks
528+
isnothing(C) || @check_eltype C t
529+
@check_eltype Vᴴ t
530+
531+
# space checks
532+
V_C = infimum(fuse(codomain(t)), fuse(domain(t)))
533+
isnothing(C) || space(C) == (codomain(t) V_C) ||
534+
throw(SpaceMismatch("`right_orth!(t, (C, Vᴴ))` requires `space(C) == (codomain(t) ← infimum(fuse(codomain(t)), fuse(domain(t)))`"))
535+
space(Vᴴ) == (V_dom domain(t)) ||
536+
throw(SpaceMismatch("`right_orth!(t, (C, Vᴴ))` requires `space(Vᴴ) == (infimum(fuse(codomain(t)), fuse(domain(t))) ← domain(t))`"))
537+
538+
return nothing
539+
end
540+
541+
function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), t::AbstractTensorMap)
542+
V_C = infimum(fuse(codomain(t)), fuse(domain(t)))
543+
V = similar(t, codomain(t) V_C)
544+
C = similar(t, V_C domain(t))
545+
return V, C
546+
end
547+
548+
function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), t::AbstractTensorMap)
549+
V_C = infimum(fuse(codomain(t)), fuse(domain(t)))
550+
C = similar(t, codomain(t) V_C)
551+
Vᴴ = similar(t, V_C domain(t))
552+
return C, Vᴴ
553+
end
554+
555+
function MatrixAlgebraKit.left_orth!(t::AbstractTensorMap, VC; kwargs...)
556+
MatrixAlgebraKit.check_input(left_orth!, t, VC)
557+
atol = get(kwargs, :atol, 0)
558+
rtol = get(kwargs, :rtol, 0)
559+
kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :qrpos : :svd)
560+
561+
if !(iszero(atol) && iszero(rtol)) && kind != :svd
562+
throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind"))
563+
end
564+
565+
if kind == :qr
566+
alg = get(kwargs, :alg, MatrixAlgebraKit.select_algorithm(qr_compact!, t))
567+
return qr_compact!(t, VC, alg)
568+
elseif kind == :qrpos
569+
alg = get(kwargs, :alg,
570+
MatrixAlgebraKit.select_algorithm(qr_compact!, t; positive=true))
571+
return qr_compact!(t, VC, alg)
572+
elseif kind == :polar
573+
alg = get(kwargs, :alg, MatrixAlgebraKit.select_algorithm(left_polar!, t))
574+
return left_polar!(t, VC, alg)
575+
elseif kind == :svd && iszero(atol) && iszero(rtol)
576+
alg = get(kwargs, :alg, MatrixAlgebraKit.select_algorithm(svd_compact!, t))
577+
V, C = VC
578+
S = DiagonalTensorMap{real(scalartype(t))}(undef, domain(V) codomain(C))
579+
U, S, Vᴴ = svd_compact!(t, (V, S, C), alg)
580+
return U, lmul!(S, Vᴴ)
581+
elseif kind == :svd
582+
alg_svd = MatrixAlgebraKit.select_algorithm(svd_compact!, t)
583+
trunc = MatrixAlgebraKit.TruncationKeepAbove(atol, rtol)
584+
alg = get(kwargs, :alg, MatrixAlgebraKit.TruncatedAlgorithm(alg_svd, trunc))
585+
V, C = VC
586+
S = DiagonalTensorMap{real(scalartype(t))}(undef, domain(V) codomain(C))
587+
U, S, Vᴴ = svd_trunc!(t, (V, S, C), alg)
588+
return U, lmul!(S, Vᴴ)
589+
else
590+
throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`"))
591+
end
592+
end
593+
594+
# Truncation
595+
# ----------
596+
# TODO: technically we could do this truncation in-place, but this might not be worth it
597+
function MatrixAlgebraKit.truncate!(::typeof(svd_trunc!), (U, S, Vᴴ),
598+
trunc::MatrixAlgebraKit.TruncationKeepAbove)
599+
atol = max(trunc.atol, norm(S) * trunc.rtol)
600+
V_truncated = spacetype(S)(c => findlast(>=(atol), b.diag) for (c, b) in blocks(S))
601+
602+
= similar(U, codomain(U) V_truncated)
603+
for (c, b) in blocks(Ũ)
604+
copy!(b, @view(block(U, c)[:, 1:size(b, 2)]))
605+
end
606+
607+
= DiagonalTensorMap{scalartype(S)}(undef, V_truncated)
608+
for (c, b) in blocks(S̃)
609+
copy!(b.diag, @view(block(S, c).diag[1:size(b, 1)]))
610+
end
611+
612+
Ṽᴴ = similar(Vᴴ, V_truncated domain(Vᴴ))
613+
for (c, b) in blocks(Ṽᴴ)
614+
copy!(b, @view(block(Vᴴ, c)[1:size(b, 1), :]))
615+
end
616+
617+
return Ũ, S̃, Ṽᴴ
618+
end

0 commit comments

Comments
 (0)