diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index 7c349c885..acf28c4d4 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -59,7 +59,9 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) t₀ = Base.time() permute(Vsrc, p) == Vdst || throw(SpaceMismatch("Incompatible spaces for permuting.")) structure_dst = fusionblockstructure(Vdst) + fusionstructure_dst = structure_dst.fusiontreestructure structure_src = fusionblockstructure(Vsrc) + fusionstructure_src = structure_src.fusiontreestructure I = sectortype(Vsrc) uncoupleds_src = map(structure_src.fusiontreelist) do (f₁, f₂) @@ -78,15 +80,15 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) # TODO: this can be multithreaded for (i, uncoupled) in enumerate(uncoupleds_src_unique) - ids_src = findall(==(uncoupled), uncoupleds_src) - fusiontrees_outer_src = structure_src.fusiontreelist[ids_src] + inds_src = findall(==(uncoupled), uncoupleds_src) + fusiontrees_outer_src = structure_src.fusiontreelist[inds_src] uncoupled_dst = TupleTools.getindices(uncoupled, (p[1]..., p[2]...)) - ids_dst = findall(==(uncoupled_dst), uncoupleds_dst) + inds_dst = findall(==(uncoupled_dst), uncoupleds_dst) - fusiontrees_outer_dst = structure_dst.fusiontreelist[ids_dst] + fusiontrees_outer_dst = structure_dst.fusiontreelist[inds_dst] - matrix = zeros(sectorscalartype(I), length(ids_dst), length(ids_src)) + matrix = zeros(sectorscalartype(I), length(inds_dst), length(inds_src)) for (row, (f₁, f₂)) in enumerate(fusiontrees_outer_src) for ((f₃, f₄), coeff) in transform(f₁, f₂) col = findfirst(==((f₃, f₄)), fusiontrees_outer_dst)::Int @@ -94,13 +96,10 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) end end - structs_src = structure_src.fusiontreestructure[ids_src] - sz_src = structs_src[1][1] - newstructs_src = map(x -> (x[2], x[3]), structs_src) - - structs_dst = structure_dst.fusiontreestructure[ids_dst] - sz_dst = structs_dst[1][1] - newstructs_dst = map(x -> (x[2], x[3]), structs_dst) + # size is shared between blocks, so repack: + # from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...]) + sz_src, newstructs_src = repack_transformer_structure(fusionstructure_src, inds_src) + sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst, inds_dst) @debug("Created recoupling block for uncoupled: $uncoupled", sz = size(matrix), sparsity = count(!iszero, matrix) / length(matrix)) @@ -124,6 +123,12 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) return transformer end +function repack_transformer_structure(structures, ids) + sz = structures[first(ids)][1] + strides_offsets = map(i -> (structures[i][2], structures[i][3]), ids) + return sz, strides_offsets +end + function buffersize(transformer::GenericTreeTransformer) return maximum(transformer.data; init=0) do (basistransform, structures_dst, _) return prod(structures_dst[1]) * size(basistransform, 1)