Skip to content

Commit 1f384cf

Browse files
committed
improve sampling space generation for local expansion
1 parent dab6207 commit 1f384cf

2 files changed

Lines changed: 41 additions & 9 deletions

File tree

src/algorithms/changebonds/localexpand.jl

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
struct NoExpand <: Algorithm end
22

33
function changebonds_left(AL, Cs, alg; kwargs...)
4-
# @info "Old size: $(dim(right_virtualspace(AL)))"
54
AL, Cs = changebonds(; expand_rightspace = AL, embed_leftspace = Cs, alg, kwargs...)[[2,end]]
6-
# @info "New size: $(dim(right_virtualspace(AL)))"
75
return AL, Cs
86
end
97
function changebonds_right(Cs, AR, alg; kwargs...)
@@ -123,9 +121,46 @@ end
123121
extract_sector_types(::Type{GradedSpace{S,D}}) where {S<:Sector,D} = (S,)
124122
extract_sector_types(::Type{GradedSpace{ProductSector{T},D}}) where {T<:Tuple,D} = Tuple(T.parameters)
125123
extract_sector_types(sp::GradedSpace) = extract_sector_types(typeof(sp))
126-
function generate_sampling_space(psi::MPSKit.AbstractMPS, cutoff::Integer=100)
124+
function generate_sampling_space(psi::MPSKit.AbstractMPS; cutoff::Integer=100, minsize::Integer=1)
127125
sp = physicalspace(psi.AL[1])
128-
x = extract_sector_types(sp)
129-
iterator = Iterators.product((Iterators.take(values(T),cutoff) for T in x)...)
130-
return typeof(sp)([T=>1 for T in iterator])
126+
I = sectortype(sp)
127+
types = extract_sector_types(sp)
128+
iterators = [values(T) for T in types]
129+
iterator = constrained_product(iterators, cutoff)
130+
131+
r = collect(I(T) => minsize for T in iterator)
132+
if sp isa GradedSpace
133+
b=TensorKit.SectorDict{I, Int}(r)
134+
return GradedSpace{I, TensorKit.SectorDict{I, Int}}(b, false)
135+
end
136+
return typeof(sp)(r)
137+
end
138+
139+
function constrained_product(iters, cutoff)
140+
N = length(iters)
141+
142+
# 1. Materialize the necessary prefixes (0 to cutoff -> cutoff + 1 elements)
143+
prefixes = [collect(Iterators.take(it, cutoff + 1)) for it in iters]
144+
145+
# 2. Lazy generation via Channel
146+
return Channel() do channel
147+
function recurse(dim, current_sum, current_tuple)
148+
if dim == N
149+
# Remaining budget for the last dimension
150+
remaining = cutoff - current_sum
151+
for i in 0:remaining
152+
# i+1 because Julia is 1-indexed for array access
153+
put!(channel, (current_tuple..., prefixes[dim][i+1]))
154+
end
155+
else
156+
# Each index can be anything from 0 up to the remaining budget
157+
upper_bound = cutoff - current_sum
158+
for i in 0:upper_bound
159+
recurse(dim + 1, current_sum + i, (current_tuple..., prefixes[dim][i+1]))
160+
end
161+
end
162+
end
163+
164+
recurse(1, 0, ())
165+
end
131166
end

src/utility/dynamictruncation.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,6 @@ function updatetruncation(alg::DynamicTruncation; iter::Integer=0, current_rank:
121121

122122
new_maxrank = int_clamp(alg.maxrank, nothing, alg.maxrank_max, rank_factor)
123123
new_maxrank = isnothing(new_maxrank) ? nothing : max(0, new_maxrank - current_rank)
124-
if !iszero(current_rank)
125-
@info "current rank: $current_rank, new maxrank: $new_maxrank"
126-
end
127124

128125
strategy = MatrixAlgebraKit.TruncationStrategy(;
129126
atol = new_atol,

0 commit comments

Comments
 (0)