1- using MatrixAlgebraKit: TruncationStrategy, diagview, eig_trunc!, eigh_trunc!, svd_trunc!
2-
3- function MatrixAlgebraKit. diagview (A:: BlockSparseMatrix{T,Diagonal{T,Vector{T}}} ) where {T}
4- D = BlockSparseVector {T} (undef, axes (A, 1 ))
5- for I in eachblockstoredindex (A)
6- if == (Int .(Tuple (I))... )
7- D[Tuple (I)[1 ]] = diagview (A[I])
8- end
9- end
10- return D
11- end
1+ using MatrixAlgebraKit:
2+ TruncationStrategy,
3+ diagview,
4+ eig_trunc!,
5+ eigh_trunc!,
6+ findtruncated,
7+ svd_trunc!,
8+ truncate!
129
1310"""
1411 BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy)
@@ -27,7 +24,7 @@ function MatrixAlgebraKit.truncate!(
2724 strategy:: TruncationStrategy ,
2825)
2926 # TODO assert blockdiagonal
30- return MatrixAlgebraKit . truncate! (
27+ return truncate! (
3128 svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy (strategy)
3229 )
3330end
@@ -38,9 +35,7 @@ for f in [:eig_trunc!, :eigh_trunc!]
3835 (D, V):: NTuple{2,AbstractBlockSparseMatrix} ,
3936 strategy:: TruncationStrategy ,
4037 )
41- return MatrixAlgebraKit. truncate! (
42- $ f, (D, V), BlockPermutedDiagonalTruncationStrategy (strategy)
43- )
38+ return truncate! ($ f, (D, V), BlockPermutedDiagonalTruncationStrategy (strategy))
4439 end
4540 end
4641end
5045function MatrixAlgebraKit. findtruncated (
5146 values:: AbstractVector , strategy:: BlockPermutedDiagonalTruncationStrategy
5247)
53- ind = MatrixAlgebraKit . findtruncated (values, strategy. strategy)
48+ ind = findtruncated (Vector ( values) , strategy. strategy)
5449 indexmask = falses (length (values))
5550 indexmask[ind] .= true
56- return indexmask
51+ return to_truncated_indices (values, indexmask)
52+ end
53+
54+ # Allow customizing the indices output by `findtruncated`
55+ # based on the type of `values`, for example to preserve
56+ # a block or Kronecker structure.
57+ to_truncated_indices (values, I) = I
58+ function to_truncated_indices (values:: AbstractBlockVector , I:: AbstractVector{Bool} )
59+ I′ = BlockedVector (I, blocklengths (axis (values)))
60+ blocks = map (BlockRange (values)) do b
61+ return _getindex (b, to_truncated_indices (values[b], I′[b]))
62+ end
63+ return blocks
5764end
5865
5966function MatrixAlgebraKit. truncate! (
6067 :: typeof (svd_trunc!),
6168 (U, S, Vᴴ):: NTuple{3,AbstractBlockSparseMatrix} ,
6269 strategy:: BlockPermutedDiagonalTruncationStrategy ,
6370)
64- I = MatrixAlgebraKit. findtruncated (diagview (S), strategy)
71+ I = MatrixAlgebraKit. findtruncated (diag (S), strategy)
6572 return (U[:, I], S[I, I], Vᴴ[I, :])
6673end
6774for f in [:eig_trunc! , :eigh_trunc! ]
@@ -71,7 +78,7 @@ for f in [:eig_trunc!, :eigh_trunc!]
7178 (D, V):: NTuple{2,AbstractBlockSparseMatrix} ,
7279 strategy:: BlockPermutedDiagonalTruncationStrategy ,
7380 )
74- I = MatrixAlgebraKit. findtruncated (diagview (D), strategy)
81+ I = MatrixAlgebraKit. findtruncated (diag (D), strategy)
7582 return (D[I, I], V[:, I])
7683 end
7784 end
0 commit comments