@@ -65,6 +65,47 @@ function truncate!(::typeof(left_null!),
6565 return Ũ
6666end
6767
68+ function truncate! (:: typeof (eigh_trunc!), (D, V):: _T_DV , strategy:: TruncationStrategy )
69+ ind = findtruncated (diagview (D), strategy)
70+ V_truncated = spacetype (D)(c => length (I) for (c, I) in ind)
71+
72+ D̃ = DiagonalTensorMap {scalartype(D)} (undef, V_truncated)
73+ for (c, b) in blocks (D̃)
74+ I = get (ind, c, nothing )
75+ @assert ! isnothing (I)
76+ copy! (b. diag, @view (block (D, c). diag[I]))
77+ end
78+
79+ Ṽ = similar (V, V_truncated ← domain (V))
80+ for (c, b) in blocks (Ṽ)
81+ I = get (ind, c, nothing )
82+ @assert ! isnothing (I)
83+ copy! (b, @view (block (V, c)[I, :]))
84+ end
85+
86+ return D̃, Ṽ
87+ end
88+ function truncate! (:: typeof (eig_trunc!), (D, V):: _T_DV , strategy:: TruncationStrategy )
89+ ind = findtruncated (diagview (D), strategy)
90+ V_truncated = spacetype (D)(c => length (I) for (c, I) in ind)
91+
92+ D̃ = DiagonalTensorMap {scalartype(D)} (undef, V_truncated)
93+ for (c, b) in blocks (D̃)
94+ I = get (ind, c, nothing )
95+ @assert ! isnothing (I)
96+ copy! (b. diag, @view (block (D, c). diag[I]))
97+ end
98+
99+ Ṽ = similar (V, V_truncated ← domain (V))
100+ for (c, b) in blocks (Ṽ)
101+ I = get (ind, c, nothing )
102+ @assert ! isnothing (I)
103+ copy! (b, @view (block (V, c)[I, :]))
104+ end
105+
106+ return D̃, Ṽ
107+ end
108+
68109# Find truncation
69110# ---------------
70111# auxiliary functions
@@ -88,18 +129,28 @@ function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}) where {I<:Sector}
88129 return σmin, keys (truncdim)[imin]
89130end
90131
91- # sorted implementations
132+ # implementations
92133function findtruncated_sorted (S:: SectorDict , strategy:: TruncationKeepAbove )
93134 atol = rtol_to_atol (S, strategy. p, strategy. atol, strategy. rtol)
94135 findtrunc = Base. Fix2 (findtruncated_sorted, truncbelow (atol))
95136 return SectorDict (c => findtrunc (d) for (c, d) in Sd)
96137end
138+ function findtruncated (S:: SectorDict , strategy:: TruncationKeepAbove )
139+ atol = rtol_to_atol (S, strategy. p, strategy. atol, strategy. rtol)
140+ findtrunc = Base. Fix2 (findtruncated, truncbelow (atol))
141+ return SectorDict (c => findtrunc (d) for (c, d) in Sd)
142+ end
97143
98144function findtruncated_sorted (S:: SectorDict , strategy:: TruncationKeepBelow )
99145 atol = rtol_to_atol (S, strategy. p, strategy. atol, strategy. rtol)
100146 findtrunc = Base. Fix2 (findtruncated_sorted, truncabove (atol))
101147 return SectorDict (c => findtrunc (d) for (c, d) in Sd)
102148end
149+ function findtruncated (S:: SectorDict , strategy:: TruncationKeepBelow )
150+ atol = rtol_to_atol (S, strategy. p, strategy. atol, strategy. rtol)
151+ findtrunc = Base. Fix2 (findtruncated, truncabove (atol))
152+ return SectorDict (c => findtrunc (d) for (c, d) in Sd)
153+ end
103154
104155function findtruncated_sorted (Sd:: SectorDict , strategy:: TruncationError )
105156 I = keytype (Sd)
153204function findtruncated_sorted (Sd:: SectorDict , strategy:: TruncationKeepFiltered )
154205 return SectorDict (c => findtruncated_sorted (d, strategy) for (c, d) in Sd)
155206end
207+ function findtruncated (Sd:: SectorDict , strategy:: TruncationKeepFiltered )
208+ return SectorDict (c => findtruncated (d, strategy) for (c, d) in Sd)
209+ end
156210
157211function findtruncated_sorted (Sd:: SectorDict , strategy:: TruncationIntersection )
158212 inds = map (Base. Fix1 (findtruncated_sorted, Sd), strategy)
159213 return SectorDict (c => intersect (map (Base. Fix2 (getindex, c), inds)... )
160214 for c in intersect (map (keys, inds)... ))
161215end
216+ function findtruncated (Sd:: SectorDict , strategy:: TruncationIntersection )
217+ inds = map (Base. Fix1 (findtruncated, Sd), strategy)
218+ return SectorDict (c => intersect (map (Base. Fix2 (getindex, c), inds)... )
219+ for c in intersect (map (keys, inds)... ))
220+ end
0 commit comments