@@ -21,6 +21,18 @@ macro check_eltype(x, y, f=:identity, g=:eltype)
2121 return esc (:($ g ($ x) == $ f ($ g ($ y)) || throw (ArgumentError ($ msg))))
2222end
2323
24+ function MatrixAlgebraKit. _select_algorithm (_, :: AbstractTensorMap ,
25+ alg:: MatrixAlgebraKit.AbstractAlgorithm )
26+ return alg
27+ end
28+ function MatrixAlgebraKit. _select_algorithm (f, t:: AbstractTensorMap , alg:: NamedTuple )
29+ return MatrixAlgebraKit. select_algorithm (f, t; alg... )
30+ end
31+
32+ function _select_truncation (f, :: AbstractTensorMap ,
33+ trunc:: MatrixAlgebraKit.TruncationStrategy )
34+ return trunc
35+ end
2436# function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap)
2537# T = scalartype(t)
2638# return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T)))))
@@ -76,7 +88,7 @@ function MatrixAlgebraKit.initialize_output(::typeof(svd_full!), t::AbstractTens
7688 :: MatrixAlgebraKit.AbstractAlgorithm )
7789 V_cod = fuse (codomain (t))
7890 V_dom = fuse (domain (t))
79- U = similar (t, domain (t) ← V_cod)
91+ U = similar (t, codomain (t) ← V_cod)
8092 S = similar (t, real (scalartype (t)), V_cod ← V_dom)
8193 Vᴴ = similar (t, V_dom ← domain (t))
8294 return U, S, Vᴴ
@@ -476,18 +488,19 @@ function MatrixAlgebraKit.check_input(::typeof(left_polar!), t, (W, P))
476488 @check_eltype P t
477489
478490 # space checks
479- space (W) == ( codomain (t) ← fuse ( domain (t)) ) ||
491+ space (W) == space (t ) ||
480492 throw (SpaceMismatch (" `left_polar!(t, (W, P))` requires `space(W) == (codomain(t) ← domain(t))`" ))
481- space (P) == (fuse ( domain (t) ) ← domain (t)) ||
493+ space (P) == (domain (t) ← domain (t)) ||
482494 throw (SpaceMismatch (" `left_polar!(t, (W, P))` requires `space(P) == (domain(t) ← domain(t))`" ))
483495
484496 return nothing
485497end
486498
487499# TODO : do we really not want to fuse the spaces?
488- function MatrixAlgebraKit. initialize_output (:: typeof (left_polar!), t:: AbstractTensorMap )
489- W = similar (t, codomain (t) ← fuse (domain (t)))
490- P = similar (t, fuse (domain (t)) ← domain (t))
500+ function MatrixAlgebraKit. initialize_output (:: typeof (left_polar!), t:: AbstractTensorMap ,
501+ :: MatrixAlgebraKit.AbstractAlgorithm )
502+ W = similar (t, space (t))
503+ P = similar (t, domain (t) ← domain (t))
491504 return W, P
492505end
493506
@@ -558,40 +571,48 @@ function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), t::AbstractTe
558571 return C, Vᴴ
559572end
560573
561- function MatrixAlgebraKit. left_orth! (t:: AbstractTensorMap , VC; kwargs... )
562- MatrixAlgebraKit. check_input (left_orth!, t, VC)
563- atol = get (kwargs, :atol , 0 )
564- rtol = get (kwargs, :rtol , 0 )
565- kind = get (kwargs, :kind , iszero (atol) && iszero (rtol) ? :qrpos : :svd )
566-
567- if ! (iszero (atol) && iszero (rtol)) && kind != :svd
568- throw (ArgumentError (" nonzero tolerance not supported for left_orth with kind=$kind " ))
574+ function MatrixAlgebraKit. left_orth! (t:: AbstractTensorMap , VC;
575+ trunc= nothing ,
576+ kind= isnothing (trunc) ?
577+ :qr : :svd ,
578+ alg_qr= (; positive= true ),
579+ alg_polar= (;),
580+ alg_svd= (;))
581+ if ! isnothing (trunc) && kind != :svd
582+ throw (ArgumentError (" truncation not supported for left_orth with kind=$kind " ))
569583 end
570584
571585 if kind == :qr
572- alg = get (kwargs, :alg , MatrixAlgebraKit. select_algorithm (qr_compact!, t) )
573- return qr_compact! (t, VC, alg )
574- elseif kind == :qrpos
575- alg = get (kwargs, :alg ,
576- MatrixAlgebraKit . select_algorithm (qr_compact!, t; positive = true ))
577- return qr_compact! (t, VC, alg )
578- elseif kind == :polar
579- alg = get (kwargs, :alg , MatrixAlgebraKit . select_algorithm (left_polar!, t))
580- return left_polar! (t, VC, alg)
581- elseif kind == :svd && iszero (atol) && iszero (rtol )
582- alg = get (kwargs, :alg , MatrixAlgebraKit. select_algorithm (svd_compact!, t) )
586+ alg_qr′ = MatrixAlgebraKit. _select_algorithm (qr_compact!, t, alg_qr )
587+ return qr_compact! (t, VC, alg_qr′ )
588+ end
589+
590+ if kind == :polar
591+ alg_polar′ = MatrixAlgebraKit . _select_algorithm (left_polar!, t, alg_polar )
592+ return left_polar! (t, VC, alg_polar′)
593+ end
594+
595+ if kind == :svd && isnothing (trunc )
596+ alg_svd′ = MatrixAlgebraKit. _select_algorithm (svd_compact!, t, alg_svd )
583597 V, C = VC
584598 S = DiagonalTensorMap {real(scalartype(t))} (undef, domain (V) ← codomain (C))
585- U, S, Vᴴ = svd_compact! (t, (V, S, C), alg )
599+ U, S, Vᴴ = svd_compact! (t, (V, S, C), alg_svd′ )
586600 return U, lmul! (S, Vᴴ)
587- elseif kind == :svd
588- alg_svd = MatrixAlgebraKit. select_algorithm (svd_compact!, t)
589- trunc = MatrixAlgebraKit. TruncationKeepAbove (atol, rtol)
590- alg = get (kwargs, :alg , MatrixAlgebraKit. TruncatedAlgorithm (alg_svd, trunc))
601+ end
602+
603+ if kind == :svd
604+ alg_svd′ = MatrixAlgebraKit. _select_algorithm (svd_compact!, t, alg_svd)
605+ alg_svd_trunc = MatrixAlgebraKit. select_algorithm (svd_trunc!, t; trunc,
606+ alg= alg_svd′)
591607 V, C = VC
592608 S = DiagonalTensorMap {real(scalartype(t))} (undef, domain (V) ← codomain (C))
593- U, S, Vᴴ = svd_trunc! (t, (V, S, C), alg )
609+ U, S, Vᴴ = svd_trunc! (t, (V, S, C), alg_svd_trunc )
594610 return U, lmul! (S, Vᴴ)
611+ end
612+
613+ throw (ArgumentError (" `left_orth!` received unknown value `kind = $kind `" ))
614+ end
615+
595616 else
596617 throw (ArgumentError (" `left_orth!` received unknown value `kind = $kind `" ))
597618 end
0 commit comments