@@ -584,59 +584,68 @@ function add_transform_kernel!(
584584 tforeach (fusiontrees (tsrc); scheduler) do (f₁, f₂)
585585 (f₁′, f₂′), coeff = transformer ((f₁, f₂))
586586 @inbounds TO. tensoradd! (
587- tdst[f₁′, f₂′], tsrc[f₁, f₂],
588- p, false , α * coeff, β, backend, allocator
587+ tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false , α * coeff, β, backend, allocator
589588 )
590589 end
591- else
592- cp = TO. allocator_checkpoint! (allocator)
593- # Non-Abelian fusion: trees sharing the same set of uncoupled (external) sectors
594- # form a *fusion block* and mix under the transformation via a recoupling matrix U
595- # (rows = destination trees, columns = source trees). We iterate over blocks.
596- tforeach (fusionblocks (tsrc); scheduler) do src
597- dst, U = transformer (src)
598- if length (src) == 1
599- # Degenerate block: single tree, U is a 1×1 scalar — skip the buffer + matmul.
600- (f₁, f₂) = only (fusiontrees (src))
601- (f₁′, f₂′) = only (fusiontrees (dst))
602- @inbounds TO. tensoradd! (
603- tdst[f₁′, f₂′], tsrc[f₁, f₂],
604- p, false , α * only (U), β, backend, allocator
590+ return nothing
591+ end
592+ cp = TO. allocator_checkpoint! (allocator)
593+ # Non-Abelian fusion: trees sharing the same set of uncoupled (external) sectors
594+ # form a *fusion block* and mix under the transformation via a recoupling matrix U
595+ # (rows = destination trees, columns = source trees). We iterate over blocks.
596+
597+ # buffers have to be created without race condition: err on the side of caution
598+ buffersz = 2 * buffersize (transformer)
599+ generate_buffer = let lock = Threads. ReentrantLock (), allocator = allocator
600+ () -> @lock lock TO. tensoralloc (typeof (data_dst), buffersz, Val (true ), allocator)
601+ end
602+
603+ OhMyThreads. @tasks for src in fusionblocks (tsrc)
604+ # setup
605+ OhMyThreads. @set scheduler = scheduler
606+ OhMyThreads. @local buffer = generate_buffer ()
607+
608+ dst, U = transformer (src)
609+
610+ if length (src) == 1
611+ # Degenerate block: single tree, U is a 1×1 scalar — skip the buffer + matmul.
612+ (f₁, f₂) = only (fusiontrees (src))
613+ (f₁′, f₂′) = only (fusiontrees (dst))
614+ @inbounds TO. tensoradd! (
615+ tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false , α * only (U), β, backend, allocator
616+ )
617+ else
618+ # Multi-tree block: apply recoupling via a three-step pack → matmul → unpack.
619+ # 1. Extract: flatten each source block into a column of buffer_src
620+ # (shape blocksize × cols), using a trivial permutation so that the
621+ # index layout is canonical before the matmul.
622+ # 2. Recoupling: buffer_dst = buffer_src * U^T (blocksize × rows)
623+ # 3. Insert: scatter columns of buffer_dst to destination blocks,
624+ # applying the actual permutation p in the same step.
625+ rows, cols = size (U)
626+ sz_src = size (tsrc[first (fusiontrees (src))... ])
627+ blocksize = prod (sz_src)
628+ ptriv = (ntuple (identity, length (sz_src)), ())
629+ buffer_dst = StridedView (buffer, (blocksize, rows), (1 , blocksize), 0 )
630+ buffer_src = StridedView (buffer, (blocksize, cols), (1 , blocksize), blocksize * rows)
631+ @inbounds for (i, (f₁, f₂)) in enumerate (fusiontrees (src))
632+ TO. tensoradd! (
633+ sreshape (buffer_src[:, i], sz_src), tsrc[f₁, f₂],
634+ ptriv, false , One (), Zero (), backend, allocator
605635 )
606- else
607- # Multi-tree block: apply recoupling via a three-step pack → matmul → unpack.
608- # 1. Extract: flatten each source block into a column of buffer_src
609- # (shape blocksize × cols), using a trivial permutation so that the
610- # index layout is canonical before the matmul.
611- # 2. Recoupling: buffer_dst = buffer_src * U^T (blocksize × rows)
612- # 3. Insert: scatter columns of buffer_dst to destination blocks,
613- # applying the actual permutation p in the same step.
614- rows, cols = size (U)
615- sz_src = size (tsrc[first (fusiontrees (src))... ])
616- blocksize = prod (sz_src)
617- ptriv = (ntuple (identity, length (sz_src)), ())
618- buffer = TO. tensoralloc (storagetype (tdst), blocksize * (rows + cols), Val (true ), allocator)
619- buffer_dst = StridedView (buffer, (blocksize, rows), (1 , blocksize), 0 )
620- buffer_src = StridedView (buffer, (blocksize, cols), (1 , blocksize), blocksize * rows)
621- @inbounds for (i, (f₁, f₂)) in enumerate (fusiontrees (src))
622- TO. tensoradd! (
623- sreshape (buffer_src[:, i], sz_src), tsrc[f₁, f₂],
624- ptriv, false , One (), Zero (), backend, allocator
625- )
626- end
627- U′ = adapt_transformer (U, storagetype (tdst))
628- mul! (buffer_dst, buffer_src, transpose (StridedView (U′)))
629- @inbounds for (i, (f₃, f₄)) in enumerate (fusiontrees (dst))
630- TO. tensoradd! (
631- tdst[f₃, f₄], sreshape (buffer_dst[:, i], sz_src),
632- p, false , α, β, backend, allocator
633- )
634- end
635- TO. tensorfree! (buffer, allocator)
636636 end
637+ U′ = adapt_transformer (U, storagetype (tdst))
638+ mul! (buffer_dst, buffer_src, transpose (StridedView (U′)))
639+ @inbounds for (i, (f₃, f₄)) in enumerate (fusiontrees (dst))
640+ TO. tensoradd! (
641+ tdst[f₃, f₄], sreshape (buffer_dst[:, i], sz_src),
642+ p, false , α, β, backend, allocator
643+ )
644+ end
645+ TO. tensorfree! (buffer, allocator)
637646 end
638- TO. allocator_reset! (allocator, cp)
639647 end
648+ TO. allocator_reset! (allocator, cp)
640649 return nothing
641650end
642651
@@ -667,7 +676,19 @@ function add_transform_kernel!(
667676 # sz_{dst,src} — array shape of each block (same for all trees in the block)
668677 # structs_{dst,src}[i] — (offset, strides) into the flat data vector for tree i
669678 cp = TO. allocator_checkpoint! (allocator)
670- tforeach (transformer. data; scheduler) do (U, (sz_dst, structs_dst), (sz_src, structs_src))
679+
680+ # buffers have to be created without race condition: err on the side of caution
681+ buffersz = 2 * buffersize (transformer)
682+ generate_buffer = let lock = Threads. ReentrantLock (), allocator = allocator
683+ () -> @lock lock TO. tensoralloc (typeof (data_dst), buffersz, Val (true ), allocator)
684+ end
685+
686+ OhMyThreads. @tasks for subtransformer in transformer. data
687+ # setup
688+ OhMyThreads. @set scheduler = scheduler
689+ OhMyThreads. @local buffer = generate_buffer ()
690+ U, (sz_dst, structs_dst), (sz_src, structs_src) = subtransformer
691+
671692 if length (U) == 1
672693 # Degenerate block with a single tree: no matmul needed.
673694 coeff = only (U)
@@ -682,7 +703,6 @@ function add_transform_kernel!(
682703 rows, cols = size (U)
683704 blocksize = prod (sz_src)
684705 ptriv = (ntuple (identity, length (sz_src)), ())
685- buffer = TO. tensoralloc (typeof (data_dst), blocksize * (rows + cols), Val (true ), allocator)
686706 buffer_dst = StridedView (buffer, (blocksize, rows), (1 , blocksize), 0 )
687707 buffer_src = StridedView (buffer, (blocksize, cols), (1 , blocksize), blocksize * rows)
688708
0 commit comments