Skip to content

Commit 45a19da

Browse files
committed
multithreading is hard -- race conditions are easy...
1 parent 845a945 commit 45a19da

2 files changed

Lines changed: 75 additions & 49 deletions

File tree

src/tensors/indexmanipulations.jl

Lines changed: 69 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -584,59 +584,68 @@ function add_transform_kernel!(
584584
tforeach(fusiontrees(tsrc); scheduler) do (f₁, f₂)
585585
(f₁′, f₂′), coeff = transformer((f₁, f₂))
586586
@inbounds TO.tensoradd!(
587-
tdst[f₁′, f₂′], tsrc[f₁, f₂],
588-
p, false, α * coeff, β, backend, allocator
587+
tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β, backend, allocator
589588
)
590589
end
591-
else
592-
cp = TO.allocator_checkpoint!(allocator)
593-
# Non-Abelian fusion: trees sharing the same set of uncoupled (external) sectors
594-
# form a *fusion block* and mix under the transformation via a recoupling matrix U
595-
# (rows = destination trees, columns = source trees). We iterate over blocks.
596-
tforeach(fusionblocks(tsrc); scheduler) do src
597-
dst, U = transformer(src)
598-
if length(src) == 1
599-
# Degenerate block: single tree, U is a 1×1 scalar — skip the buffer + matmul.
600-
(f₁, f₂) = only(fusiontrees(src))
601-
(f₁′, f₂′) = only(fusiontrees(dst))
602-
@inbounds TO.tensoradd!(
603-
tdst[f₁′, f₂′], tsrc[f₁, f₂],
604-
p, false, α * only(U), β, backend, allocator
590+
return nothing
591+
end
592+
cp = TO.allocator_checkpoint!(allocator)
593+
# Non-Abelian fusion: trees sharing the same set of uncoupled (external) sectors
594+
# form a *fusion block* and mix under the transformation via a recoupling matrix U
595+
# (rows = destination trees, columns = source trees). We iterate over blocks.
596+
597+
# buffers have to be created without race condition: err on the side of caution
598+
buffersz = 2 * buffersize(transformer)
599+
generate_buffer = let lock = Threads.ReentrantLock(), allocator = allocator
600+
() -> @lock lock TO.tensoralloc(typeof(data_dst), buffersz, Val(true), allocator)
601+
end
602+
603+
OhMyThreads.@tasks for src in fusionblocks(tsrc)
604+
# setup
605+
OhMyThreads.@set scheduler = scheduler
606+
OhMyThreads.@local buffer = generate_buffer()
607+
608+
dst, U = transformer(src)
609+
610+
if length(src) == 1
611+
# Degenerate block: single tree, U is a 1×1 scalar — skip the buffer + matmul.
612+
(f₁, f₂) = only(fusiontrees(src))
613+
(f₁′, f₂′) = only(fusiontrees(dst))
614+
@inbounds TO.tensoradd!(
615+
tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * only(U), β, backend, allocator
616+
)
617+
else
618+
# Multi-tree block: apply recoupling via a three-step pack → matmul → unpack.
619+
# 1. Extract: flatten each source block into a column of buffer_src
620+
# (shape blocksize × cols), using a trivial permutation so that the
621+
# index layout is canonical before the matmul.
622+
# 2. Recoupling: buffer_dst = buffer_src * U^T (blocksize × rows)
623+
# 3. Insert: scatter columns of buffer_dst to destination blocks,
624+
# applying the actual permutation p in the same step.
625+
rows, cols = size(U)
626+
sz_src = size(tsrc[first(fusiontrees(src))...])
627+
blocksize = prod(sz_src)
628+
ptriv = (ntuple(identity, length(sz_src)), ())
629+
buffer_dst = StridedView(buffer, (blocksize, rows), (1, blocksize), 0)
630+
buffer_src = StridedView(buffer, (blocksize, cols), (1, blocksize), blocksize * rows)
631+
@inbounds for (i, (f₁, f₂)) in enumerate(fusiontrees(src))
632+
TO.tensoradd!(
633+
sreshape(buffer_src[:, i], sz_src), tsrc[f₁, f₂],
634+
ptriv, false, One(), Zero(), backend, allocator
605635
)
606-
else
607-
# Multi-tree block: apply recoupling via a three-step pack → matmul → unpack.
608-
# 1. Extract: flatten each source block into a column of buffer_src
609-
# (shape blocksize × cols), using a trivial permutation so that the
610-
# index layout is canonical before the matmul.
611-
# 2. Recoupling: buffer_dst = buffer_src * U^T (blocksize × rows)
612-
# 3. Insert: scatter columns of buffer_dst to destination blocks,
613-
# applying the actual permutation p in the same step.
614-
rows, cols = size(U)
615-
sz_src = size(tsrc[first(fusiontrees(src))...])
616-
blocksize = prod(sz_src)
617-
ptriv = (ntuple(identity, length(sz_src)), ())
618-
buffer = TO.tensoralloc(storagetype(tdst), blocksize * (rows + cols), Val(true), allocator)
619-
buffer_dst = StridedView(buffer, (blocksize, rows), (1, blocksize), 0)
620-
buffer_src = StridedView(buffer, (blocksize, cols), (1, blocksize), blocksize * rows)
621-
@inbounds for (i, (f₁, f₂)) in enumerate(fusiontrees(src))
622-
TO.tensoradd!(
623-
sreshape(buffer_src[:, i], sz_src), tsrc[f₁, f₂],
624-
ptriv, false, One(), Zero(), backend, allocator
625-
)
626-
end
627-
U′ = adapt_transformer(U, storagetype(tdst))
628-
mul!(buffer_dst, buffer_src, transpose(StridedView(U′)))
629-
@inbounds for (i, (f₃, f₄)) in enumerate(fusiontrees(dst))
630-
TO.tensoradd!(
631-
tdst[f₃, f₄], sreshape(buffer_dst[:, i], sz_src),
632-
p, false, α, β, backend, allocator
633-
)
634-
end
635-
TO.tensorfree!(buffer, allocator)
636636
end
637+
U′ = adapt_transformer(U, storagetype(tdst))
638+
mul!(buffer_dst, buffer_src, transpose(StridedView(U′)))
639+
@inbounds for (i, (f₃, f₄)) in enumerate(fusiontrees(dst))
640+
TO.tensoradd!(
641+
tdst[f₃, f₄], sreshape(buffer_dst[:, i], sz_src),
642+
p, false, α, β, backend, allocator
643+
)
644+
end
645+
TO.tensorfree!(buffer, allocator)
637646
end
638-
TO.allocator_reset!(allocator, cp)
639647
end
648+
TO.allocator_reset!(allocator, cp)
640649
return nothing
641650
end
642651

@@ -667,7 +676,19 @@ function add_transform_kernel!(
667676
# sz_{dst,src} — array shape of each block (same for all trees in the block)
668677
# structs_{dst,src}[i] — (offset, strides) into the flat data vector for tree i
669678
cp = TO.allocator_checkpoint!(allocator)
670-
tforeach(transformer.data; scheduler) do (U, (sz_dst, structs_dst), (sz_src, structs_src))
679+
680+
# buffers have to be created without race condition: err on the side of caution
681+
buffersz = 2 * buffersize(transformer)
682+
generate_buffer = let lock = Threads.ReentrantLock(), allocator = allocator
683+
() -> @lock lock TO.tensoralloc(typeof(data_dst), buffersz, Val(true), allocator)
684+
end
685+
686+
OhMyThreads.@tasks for subtransformer in transformer.data
687+
# setup
688+
OhMyThreads.@set scheduler = scheduler
689+
OhMyThreads.@local buffer = generate_buffer()
690+
U, (sz_dst, structs_dst), (sz_src, structs_src) = subtransformer
691+
671692
if length(U) == 1
672693
# Degenerate block with a single tree: no matmul needed.
673694
coeff = only(U)
@@ -682,7 +703,6 @@ function add_transform_kernel!(
682703
rows, cols = size(U)
683704
blocksize = prod(sz_src)
684705
ptriv = (ntuple(identity, length(sz_src)), ())
685-
buffer = TO.tensoralloc(typeof(data_dst), blocksize * (rows + cols), Val(true), allocator)
686706
buffer_dst = StridedView(buffer, (blocksize, rows), (1, blocksize), 0)
687707
buffer_src = StridedView(buffer, (blocksize, cols), (1, blocksize), blocksize * rows)
688708

src/tensors/treetransformers.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,9 @@ end
203203
function _transformer_weight((mat, structs_dst, structs_src)::GenericTransformerData)
204204
return length(mat) * prod(structs_dst[1])
205205
end
206+
207+
function buffersize(transformer::GenericTreeTransformer)
208+
return maximum(transformer.data; init = 0) do (basistransform, structures_dst, _)
209+
return prod(structures_dst[1]) * size(basistransform, 1)
210+
end
211+
end

0 commit comments

Comments
 (0)