-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtruncation.jl
More file actions
55 lines (48 loc) · 1.78 KB
/
truncation.jl
File metadata and controls
55 lines (48 loc) · 1.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
using MatrixAlgebraKit: TruncationStrategy, diagview, svd_trunc!
# Collect the diagonal of a block-sparse matrix whose stored blocks are `Diagonal` into a
# `BlockSparseVector` blocked like the rows of `A`.
# Only stored blocks on the block diagonal (block row == block column) contribute; stored
# off-diagonal blocks are skipped. Entries belonging to unstored diagonal blocks are whatever
# `BlockSparseVector{T}(undef, ...)` provides (NOTE(review): presumably unstored — confirm).
function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T}
    diagonal = BlockSparseVector{T}(undef, axes(A, 1))
    for blockindex in eachblockstoredindex(A)
        rowblock, colblock = Tuple(blockindex)
        # skip stored blocks that are not on the block diagonal
        Int(rowblock) == Int(colblock) || continue
        diagonal[rowblock] = diagview(A[blockindex])
    end
    return diagonal
end
"""
    BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy)

Wrap a `TruncationStrategy` for use on matrices that are block diagonal, or block diagonal
up to a permutation of their blocks. The wrapped `strategy` decides which values are kept;
that selection is then expressed as a Boolean mask, so that slicing the factors with it
keeps their block structure intact.
"""
struct BlockPermutedDiagonalTruncationStrategy{S<:TruncationStrategy} <: TruncationStrategy
    strategy::S
end
# Type of an `(U, S, Vᴴ)` SVD triple in which all three factors are block-sparse matrices.
# Tuple types are covariant, so this equals `Tuple{<:A,<:A,<:A}` for
# `A = AbstractBlockSparseMatrix`.
const TBlockUSVᴴ = NTuple{3,AbstractBlockSparseMatrix}
# Fallback for truncating block-sparse SVD factors with an arbitrary strategy: wrap the
# strategy in `BlockPermutedDiagonalTruncationStrategy` and dispatch again. The wrapper is
# itself a `TruncationStrategy`, but the method specialised on it is more specific, so this
# does not recurse.
function MatrixAlgebraKit.truncate!(
    ::typeof(svd_trunc!), USVᴴ::TBlockUSVᴴ, strategy::TruncationStrategy
)
    # TODO: assert that the factors are (block-permuted) block diagonal
    blockwise = BlockPermutedDiagonalTruncationStrategy(strategy)
    return MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ, blockwise)
end
# Regular slicing with the indices returned by the wrapped strategy could alter the block
# structure of the factors. Instead, convert those indices into a Boolean mask over
# `values`; indexing with the mask keeps the selected entries without changing the block
# layout.
# Returns a `BitVector` of length `length(values)` that is `true` at the kept positions.
function MatrixAlgebraKit.findtruncated(
    values::AbstractVector, strategy::BlockPermutedDiagonalTruncationStrategy
)
    kept = MatrixAlgebraKit.findtruncated(values, strategy.strategy)
    mask = falses(length(values))
    fill!(view(mask, kept), true)
    return mask
end
# Truncate block-sparse SVD factors using a `BlockPermutedDiagonalTruncationStrategy`.
# The Boolean mask computed from the diagonal of `S` selects the kept columns of `U`, the
# kept rows and columns of `S`, and the kept rows of `Vᴴ`.
# Returns the new triple `(U', S', Vᴴ')` built by indexing.
function MatrixAlgebraKit.truncate!(
    ::typeof(svd_trunc!),
    (U, S, Vᴴ)::TBlockUSVᴴ,
    strategy::BlockPermutedDiagonalTruncationStrategy,
)
    keep = MatrixAlgebraKit.findtruncated(diagview(S), strategy)
    Ũ = U[:, keep]
    S̃ = S[keep, keep]
    Ṽᴴ = Vᴴ[keep, :]
    return (Ũ, S̃, Ṽᴴ)
end