@@ -45,6 +45,22 @@ function MatrixAlgebraKit.findtruncated(
4545 return indexmask
4646end
4747
48+ function similar_truncate (
49+ :: typeof (svd_trunc!),
50+ (U, S, Vᴴ):: TBlockUSV ᴴ,
51+ strategy:: BlockPermutedDiagonalTruncationStrategy ,
52+ indexmask= MatrixAlgebraKit. findtruncated (diagview (S), strategy),
53+ )
54+ ax = axes (S, 1 )
55+ counter = Base. Fix1 (count, Base. Fix1 (getindex, indexmask))
56+ s_lengths = filter! (> (0 ), map (counter, blocks (ax)))
57+ s_axis = blockedrange (s_lengths)
58+ Ũ = similar (U, axes (U, 1 ), s_axis)
59+ S̃ = similar (S, s_axis, s_axis)
60+ Ṽᴴ = similar (Vᴴ, s_axis, axes (Vᴴ, 2 ))
61+ return Ũ, S̃, Ṽᴴ
62+ end
63+
4864function MatrixAlgebraKit. truncate! (
4965 :: typeof (svd_trunc!),
5066 (U, S, Vᴴ):: TBlockUSV ᴴ,
@@ -54,13 +70,7 @@ function MatrixAlgebraKit.truncate!(
5470
5571 # first determine the block structure of the output to avoid having assumptions on the
5672 # data structures
57- ax = axes (S, 1 )
58- counter = Base. Fix1 (count, Base. Fix1 (getindex, indexmask))
59- Slengths = filter! (> (0 ), map (counter, blocks (ax)))
60- Sax = blockedrange (Slengths)
61- Ũ = similar (U, axes (U, 1 ), Sax)
62- S̃ = similar (S, Sax, Sax)
63- Ṽᴴ = similar (Vᴴ, Sax, axes (Vᴴ, 2 ))
73+ Ũ, S̃, Ṽᴴ = similar_truncate (svd_trunc!, (U, S, Vᴴ), strategy, indexmask)
6474
6575 # then loop over the blocks and assign the data
6676 # TODO : figure out if we can presort and loop over the blocks -
@@ -70,6 +80,7 @@ function MatrixAlgebraKit.truncate!(
7080 bI_Vᴴs = collect (eachblockstoredindex (Vᴴ))
7181
7282 I′ = 0 # number of skipped blocks that got fully truncated
83+ ax = axes (S, 1 )
7384 for I in 1 : blocksize (ax, 1 )
7485 b = ax[Block (I)]
7586 mask = indexmask[b]
0 commit comments