@@ -223,3 +223,396 @@ function MatrixAlgebraKit.default_eigh_algorithm(t::AbstractTensorMap{<:BlasFloa
223223 return BlockAlgorithm (LAPACK_MultipleRelativelyRobustRepresentations (; kwargs... ),
224224 scheduler)
225225end
226+
227+ # QR decomposition
228+ # ----------------
229+ function MatrixAlgebraKit. check_input (:: typeof (qr_full!), t:: AbstractTensorMap ,
230+ (Q,
231+ R):: Tuple{<:AbstractTensorMap,<:AbstractTensorMap} )
232+ # scalartype checks
233+ @check_eltype Q t
234+ @check_eltype R t
235+
236+ # space checks
237+ V_Q = fuse (codomain (t))
238+ space (Q) == (codomain (t) ← V_Q) ||
239+ throw (SpaceMismatch (" `qr_full!(t, (Q, R))` requires `space(Q) == (codomain(t) ← fuse(codomain(t)))`" ))
240+ space (R) == (V_Q ← domain (t)) ||
241+ throw (SpaceMismatch (" `qr_full!(t, (Q, R))` requires `space(R) == (fuse(codomain(t)) ← domain(t)`" ))
242+
243+ return nothing
244+ end
245+
246+ function MatrixAlgebraKit. check_input (:: typeof (qr_compact!), t:: AbstractTensorMap , (Q, R))
247+ # scalartype checks
248+ @check_eltype Q t
249+ @check_eltype R t
250+
251+ # space checks
252+ V_Q = infimum (fuse (codomain (t)), fuse (domain (t)))
253+ space (Q) == (codomain (t) ← V_Q) ||
254+ throw (SpaceMismatch (" `qr_compact!(t, (Q, R))` requires `space(Q) == (codomain(t) ← infimum(fuse(codomain(t)), fuse(domain(t)))`" ))
255+ space (R) == (V_Q ← domain (t)) ||
256+ throw (SpaceMismatch (" `qr_compact!(t, (Q, R))` requires `space(R) == (infimum(fuse(codomain(t)), fuse(domain(t))) ← domain(t))`" ))
257+
258+ return nothing
259+ end
260+
261+ function MatrixAlgebraKit. check_input (:: typeof (qr_null!), t:: AbstractTensorMap ,
262+ N:: AbstractTensorMap )
263+ # scalartype checks
264+ @check_eltype N t
265+
266+ # space checks
267+ V_Q = infimum (fuse (codomain (t)), fuse (domain (t)))
268+ V_N = setdiff (fuse (codomain (t)), V_Q)
269+ space (N) == (codomain (t) ← V_N) ||
270+ throw (SpaceMismatch (" `qr_null!(t, N)` requires `space(N) == (codomain(t) ← setdiff(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`" ))
271+
272+ return nothing
273+ end
274+
275+ function MatrixAlgebraKit. initialize_output (:: typeof (qr_full!), t:: AbstractTensorMap ,
276+ :: MatrixAlgebraKit.AbstractAlgorithm )
277+ V_Q = fuse (codomain (t))
278+ Q = similar (t, codomain (t) ← V_Q)
279+ R = similar (t, V_Q ← domain (t))
280+ return Q, R
281+ end
282+
283+ function MatrixAlgebraKit. initialize_output (:: typeof (qr_compact!), t:: AbstractTensorMap ,
284+ :: MatrixAlgebraKit.AbstractAlgorithm )
285+ V_Q = infimum (fuse (codomain (t)), fuse (domain (t)))
286+ Q = similar (t, codomain (t) ← V_Q)
287+ R = similar (t, V_Q ← domain (t))
288+ return Q, R
289+ end
290+
291+ function MatrixAlgebraKit. initialize_output (:: typeof (qr_null!), t:: AbstractTensorMap ,
292+ :: MatrixAlgebraKit.AbstractAlgorithm )
293+ V_Q = infimum (fuse (codomain (t)), fuse (domain (t)))
294+ V_N = setdiff (fuse (codomain (t)), V_Q)
295+ N = similar (t, codomain (t) ← V_N)
296+ return N
297+ end
298+
299+ function MatrixAlgebraKit. qr_full! (t:: AbstractTensorMap , (Q, R),
300+ alg:: BlockAlgorithm )
301+ MatrixAlgebraKit. check_input (qr_full!, t, (Q, R))
302+
303+ foreachblock (t, Q, R; alg. scheduler) do _, (b, q, r)
304+ q′, r′ = qr_full! (b, (q, r), alg. alg)
305+ # deal with the case where the output is not the same as the input
306+ q === q′ || copyto! (q, q′)
307+ r === r′ || copyto! (r, r′)
308+ return nothing
309+ end
310+
311+ return Q, R
312+ end
313+
314+ function MatrixAlgebraKit. qr_compact! (t:: AbstractTensorMap , (Q, R),
315+ alg:: BlockAlgorithm )
316+ MatrixAlgebraKit. check_input (qr_compact!, t, (Q, R))
317+
318+ foreachblock (t, Q, R; alg. scheduler) do _, (b, q, r)
319+ q′, r′ = qr_compact! (b, (q, r), alg. alg)
320+ # deal with the case where the output is not the same as the input
321+ q === q′ || copyto! (q, q′)
322+ r === r′ || copyto! (r, r′)
323+ return nothing
324+ end
325+
326+ return Q, R
327+ end
328+
329+ function MatrixAlgebraKit. qr_null! (t:: AbstractTensorMap , N, alg:: BlockAlgorithm )
330+ MatrixAlgebraKit. check_input (qr_null!, t, N)
331+
332+ foreachblock (t, N; alg. scheduler) do _, (b, n)
333+ n′ = qr_null! (b, n, alg. alg)
334+ # deal with the case where the output is not the same as the input
335+ n === n′ || copyto! (n, n′)
336+ return nothing
337+ end
338+
339+ return N
340+ end
341+
342+ function MatrixAlgebraKit. default_qr_algorithm (t:: AbstractTensorMap{<:BlasFloat} ;
343+ scheduler= default_blockscheduler (t),
344+ kwargs... )
345+ return BlockAlgorithm (LAPACK_HouseholderQR (; kwargs... ), scheduler)
346+ end
347+
348+ # LQ decomposition
349+ # ----------------
350+ function MatrixAlgebraKit. check_input (:: typeof (lq_full!), t:: AbstractTensorMap , (L, Q))
351+ # scalartype checks
352+ @check_eltype L t
353+ @check_eltype Q t
354+
355+ # space checks
356+ V_Q = fuse (domain (t))
357+ space (L) == (codomain (t) ← V_Q) ||
358+ throw (SpaceMismatch (" `lq_full!(t, (L, Q))` requires `space(L) == (codomain(t) ← fuse(domain(t)))`" ))
359+ space (Q) == (V_Q ← domain (t)) ||
360+ throw (SpaceMismatch (" `lq_full!(t, (L, Q))` requires `space(Q) == (fuse(domain(t)) ← domain(t))`" ))
361+
362+ return nothing
363+ end
364+
365+ function MatrixAlgebraKit. check_input (:: typeof (lq_compact!), t:: AbstractTensorMap , (L, Q))
366+ # scalartype checks
367+ @check_eltype L t
368+ @check_eltype Q t
369+
370+ # space checks
371+ V_Q = infimum (fuse (codomain (t)), fuse (domain (t)))
372+ space (L) == (codomain (t) ← V_Q) ||
373+ throw (SpaceMismatch (" `lq_compact!(t, (L, Q))` requires `space(L) == infimum(fuse(codomain(t)), fuse(domain(t)))`" ))
374+ space (Q) == (V_Q ← domain (t)) ||
375+ throw (SpaceMismatch (" `lq_compact!(t, (L, Q))` requires `space(Q) == infimum(fuse(codomain(t)), fuse(domain(t)))`" ))
376+
377+ return nothing
378+ end
379+
380+ function MatrixAlgebraKit. check_input (:: typeof (lq_null!), t:: AbstractTensorMap , N)
381+ # scalartype checks
382+ @check_eltype N t
383+
384+ # space checks
385+ V_Q = infimum (fuse (codomain (t)), fuse (domain (t)))
386+ V_N = setdiff (fuse (domain (t)), V_Q)
387+ space (N) == (V_N ← domain (t)) ||
388+ throw (SpaceMismatch (" `lq_null!(t, N)` requires `space(N) == setdiff(fuse(domain(t)), infimum(fuse(codomain(t)), fuse(domain(t)))`" ))
389+
390+ return nothing
391+ end
392+
393+ function MatrixAlgebraKit. initialize_output (:: typeof (lq_full!), t:: AbstractTensorMap ,
394+ :: MatrixAlgebraKit.AbstractAlgorithm )
395+ V_Q = fuse (domain (t))
396+ L = similar (t, codomain (t) ← V_Q)
397+ Q = similar (t, V_Q ← domain (t))
398+ return L, Q
399+ end
400+
401+ function MatrixAlgebraKit. initialize_output (:: typeof (lq_compact!), t:: AbstractTensorMap ,
402+ :: MatrixAlgebraKit.AbstractAlgorithm )
403+ V_Q = infimum (fuse (codomain (t)), fuse (domain (t)))
404+ L = similar (t, codomain (t) ← V_Q)
405+ Q = similar (t, V_Q ← domain (t))
406+ return L, Q
407+ end
408+
409+ function MatrixAlgebraKit. initialize_output (:: typeof (lq_null!), t:: AbstractTensorMap ,
410+ :: MatrixAlgebraKit.AbstractAlgorithm )
411+ V_Q = infimum (fuse (codomain (t)), fuse (domain (t)))
412+ V_N = setdiff (fuse (domain (t)), V_Q)
413+ N = similar (t, V_N ← domain (t))
414+ return N
415+ end
416+
417+ function MatrixAlgebraKit. lq_full! (t:: AbstractTensorMap , (L, Q),
418+ alg:: BlockAlgorithm )
419+ MatrixAlgebraKit. check_input (lq_full!, t, (L, Q))
420+
421+ foreachblock (t, L, Q; alg. scheduler) do _, (b, l, q)
422+ l′, q′ = lq_full! (b, (l, q), alg. alg)
423+ # deal with the case where the output is not the same as the input
424+ l === l′ || copyto! (l, l′)
425+ q === q′ || copyto! (q, q′)
426+ return nothing
427+ end
428+
429+ return L, Q
430+ end
431+
432+ function MatrixAlgebraKit. lq_compact! (t:: AbstractTensorMap , (L, Q),
433+ alg:: BlockAlgorithm )
434+ MatrixAlgebraKit. check_input (lq_compact!, t, (L, Q))
435+
436+ foreachblock (t, L, Q; alg. scheduler) do _, (b, l, q)
437+ l′, q′ = lq_compact! (b, (l, q), alg. alg)
438+ # deal with the case where the output is not the same as the input
439+ l === l′ || copyto! (l, l′)
440+ q === q′ || copyto! (q, q′)
441+ return nothing
442+ end
443+
444+ return L, Q
445+ end
446+
447+ function MatrixAlgebraKit. lq_null! (t:: AbstractTensorMap , N, alg:: BlockAlgorithm )
448+ MatrixAlgebraKit. check_input (lq_null!, t, N)
449+
450+ foreachblock (t, N; alg. scheduler) do _, (b, n)
451+ n′ = lq_null! (b, n, alg. alg)
452+ # deal with the case where the output is not the same as the input
453+ n === n′ || copyto! (n, n′)
454+ return nothing
455+ end
456+
457+ return N
458+ end
459+
460+ # Polar decomposition
461+ # -------------------
462+ using MatrixAlgebraKit: PolarViaSVD
463+
464+ function MatrixAlgebraKit. check_input (:: typeof (left_polar!), t, (W, P))
465+ codomain (t) ≿ domain (t) ||
466+ throw (ArgumentError (" Polar decomposition requires `codomain(t) ≿ domain(t)`" ))
467+
468+ # scalartype checks
469+ @check_eltype W t
470+ @check_eltype P t
471+
472+ # space checks
473+ space (W) == (codomain (t) ← fuse (domain (t))) ||
474+ throw (SpaceMismatch (" `left_polar!(t, (W, P))` requires `space(W) == (codomain(t) ← domain(t))`" ))
475+ space (P) == (fuse (domain (t)) ← domain (t)) ||
476+ throw (SpaceMismatch (" `left_polar!(t, (W, P))` requires `space(P) == (domain(t) ← domain(t))`" ))
477+
478+ return nothing
479+ end
480+
481+ # TODO : do we really not want to fuse the spaces?
482+ function MatrixAlgebraKit. initialize_output (:: typeof (left_polar!), t:: AbstractTensorMap )
483+ W = similar (t, codomain (t) ← fuse (domain (t)))
484+ P = similar (t, fuse (domain (t)) ← domain (t))
485+ return W, P
486+ end
487+
488+ function MatrixAlgebraKit. left_polar! (t:: AbstractTensorMap , WP, alg:: BlockAlgorithm )
489+ MatrixAlgebraKit. check_input (left_polar!, t, WP)
490+
491+ foreachblock (t, WP... ; alg. scheduler) do _, (b, w, p)
492+ w′, p′ = left_polar! (b, (w, p), alg. alg)
493+ # deal with the case where the output is not the same as the input
494+ w === w′ || copyto! (w, w′)
495+ p === p′ || copyto! (p, p′)
496+ return nothing
497+ end
498+
499+ return WP
500+ end
501+
502+ function MatrixAlgebraKit. default_polar_algorithm (t:: AbstractTensorMap{<:BlasFloat} ;
503+ scheduler= default_blockscheduler (t),
504+ kwargs... )
505+ return BlockAlgorithm (PolarViaSVD (LAPACK_DivideAndConquer (; kwargs... )),
506+ scheduler)
507+ end
508+
509+ # Orthogonalization
510+ # -----------------
511+ function MatrixAlgebraKit. check_input (:: typeof (left_orth!), t:: AbstractTensorMap , (V, C))
512+ # scalartype checks
513+ @check_eltype V t
514+ isnothing (C) || @check_eltype C t
515+
516+ # space checks
517+ V_C = infimum (fuse (codomain (t)), fuse (domain (t)))
518+ space (V) == (codomain (t) ← V_C) ||
519+ throw (SpaceMismatch (" `left_orth!(t, (V, C))` requires `space(V) == (codomain(t) ← infimum(fuse(codomain(t)), fuse(domain(t))))`" ))
520+ isnothing (C) || space (C) == (V_C ← domain (t)) ||
521+ throw (SpaceMismatch (" `left_orth!(t, (V, C))` requires `space(C) == (infimum(fuse(codomain(t)), fuse(domain(t))) ← domain(t))`" ))
522+
523+ return nothing
524+ end
525+
526+ function MatrixAlgebraKit. check_input (:: typeof (right_orth!), t:: AbstractTensorMap , (C, Vᴴ))
527+ # scalartype checks
528+ isnothing (C) || @check_eltype C t
529+ @check_eltype Vᴴ t
530+
531+ # space checks
532+ V_C = infimum (fuse (codomain (t)), fuse (domain (t)))
533+ isnothing (C) || space (C) == (codomain (t) ← V_C) ||
534+ throw (SpaceMismatch (" `right_orth!(t, (C, Vᴴ))` requires `space(C) == (codomain(t) ← infimum(fuse(codomain(t)), fuse(domain(t)))`" ))
535+ space (Vᴴ) == (V_dom ← domain (t)) ||
536+ throw (SpaceMismatch (" `right_orth!(t, (C, Vᴴ))` requires `space(Vᴴ) == (infimum(fuse(codomain(t)), fuse(domain(t))) ← domain(t))`" ))
537+
538+ return nothing
539+ end
540+
541+ function MatrixAlgebraKit. initialize_output (:: typeof (left_orth!), t:: AbstractTensorMap )
542+ V_C = infimum (fuse (codomain (t)), fuse (domain (t)))
543+ V = similar (t, codomain (t) ← V_C)
544+ C = similar (t, V_C ← domain (t))
545+ return V, C
546+ end
547+
548+ function MatrixAlgebraKit. initialize_output (:: typeof (right_orth!), t:: AbstractTensorMap )
549+ V_C = infimum (fuse (codomain (t)), fuse (domain (t)))
550+ C = similar (t, codomain (t) ← V_C)
551+ Vᴴ = similar (t, V_C ← domain (t))
552+ return C, Vᴴ
553+ end
554+
555+ function MatrixAlgebraKit. left_orth! (t:: AbstractTensorMap , VC; kwargs... )
556+ MatrixAlgebraKit. check_input (left_orth!, t, VC)
557+ atol = get (kwargs, :atol , 0 )
558+ rtol = get (kwargs, :rtol , 0 )
559+ kind = get (kwargs, :kind , iszero (atol) && iszero (rtol) ? :qrpos : :svd )
560+
561+ if ! (iszero (atol) && iszero (rtol)) && kind != :svd
562+ throw (ArgumentError (" nonzero tolerance not supported for left_orth with kind=$kind " ))
563+ end
564+
565+ if kind == :qr
566+ alg = get (kwargs, :alg , MatrixAlgebraKit. select_algorithm (qr_compact!, t))
567+ return qr_compact! (t, VC, alg)
568+ elseif kind == :qrpos
569+ alg = get (kwargs, :alg ,
570+ MatrixAlgebraKit. select_algorithm (qr_compact!, t; positive= true ))
571+ return qr_compact! (t, VC, alg)
572+ elseif kind == :polar
573+ alg = get (kwargs, :alg , MatrixAlgebraKit. select_algorithm (left_polar!, t))
574+ return left_polar! (t, VC, alg)
575+ elseif kind == :svd && iszero (atol) && iszero (rtol)
576+ alg = get (kwargs, :alg , MatrixAlgebraKit. select_algorithm (svd_compact!, t))
577+ V, C = VC
578+ S = DiagonalTensorMap {real(scalartype(t))} (undef, domain (V) ← codomain (C))
579+ U, S, Vᴴ = svd_compact! (t, (V, S, C), alg)
580+ return U, lmul! (S, Vᴴ)
581+ elseif kind == :svd
582+ alg_svd = MatrixAlgebraKit. select_algorithm (svd_compact!, t)
583+ trunc = MatrixAlgebraKit. TruncationKeepAbove (atol, rtol)
584+ alg = get (kwargs, :alg , MatrixAlgebraKit. TruncatedAlgorithm (alg_svd, trunc))
585+ V, C = VC
586+ S = DiagonalTensorMap {real(scalartype(t))} (undef, domain (V) ← codomain (C))
587+ U, S, Vᴴ = svd_trunc! (t, (V, S, C), alg)
588+ return U, lmul! (S, Vᴴ)
589+ else
590+ throw (ArgumentError (" `left_orth!` received unknown value `kind = $kind `" ))
591+ end
592+ end
593+
594+ # Truncation
595+ # ----------
596+ # TODO : technically we could do this truncation in-place, but this might not be worth it
597+ function MatrixAlgebraKit. truncate! (:: typeof (svd_trunc!), (U, S, Vᴴ),
598+ trunc:: MatrixAlgebraKit.TruncationKeepAbove )
599+ atol = max (trunc. atol, norm (S) * trunc. rtol)
600+ V_truncated = spacetype (S)(c => findlast (>= (atol), b. diag) for (c, b) in blocks (S))
601+
602+ Ũ = similar (U, codomain (U) ← V_truncated)
603+ for (c, b) in blocks (Ũ)
604+ copy! (b, @view (block (U, c)[:, 1 : size (b, 2 )]))
605+ end
606+
607+ S̃ = DiagonalTensorMap {scalartype(S)} (undef, V_truncated)
608+ for (c, b) in blocks (S̃)
609+ copy! (b. diag, @view (block (S, c). diag[1 : size (b, 1 )]))
610+ end
611+
612+ Ṽᴴ = similar (Vᴴ, V_truncated ← domain (Vᴴ))
613+ for (c, b) in blocks (Ṽᴴ)
614+ copy! (b, @view (block (Vᴴ, c)[1 : size (b, 1 ), :]))
615+ end
616+
617+ return Ũ, S̃, Ṽᴴ
618+ end
0 commit comments