@@ -624,7 +624,8 @@ function add_transform_kernel!(
624624 ptriv, false , One (), Zero (), backend, allocator
625625 )
626626 end
627- mul! (buffer_dst, buffer_src, transpose (StridedView (U)))
627+ U′ = adapt_transformer (U, storagetype (tdst))
628+ mul! (buffer_dst, buffer_src, transpose (StridedView (U′)))
628629 @inbounds for (i, (f₃, f₄)) in enumerate (fusiontrees (dst))
629630 TO. tensoradd! (
630631 tdst[f₃, f₄], sreshape (buffer_dst[:, i], sz_src),
@@ -661,7 +662,6 @@ function add_transform_kernel!(
661662 data_dst:: DenseVector , data_src:: DenseVector , p, transformer:: GenericTreeTransformer ,
662663 α, β, backend, allocator, scheduler
663664 )
664- transformer = adapt_transformer (transformer, data_dst)
665665 # Each entry covers one fusion block:
666666 # U — recoupling matrix (rows = dst trees, cols = src trees)
667667 # sz_{dst,src} — array shape of each block (same for all trees in the block)
@@ -697,7 +697,8 @@ function add_transform_kernel!(
697697
698698 # 2. Recoupling: buffer_dst = buffer_src * U^T (each output tree is a linear
699699 # combination of input trees weighted by the recoupling coefficients).
700- mul! (buffer_dst, buffer_src, transpose (StridedView (U)))
700+ U′ = adapt_transformer (U, typeof (data_dst))
701+ mul! (buffer_dst, buffer_src, transpose (StridedView (U′)))
701702
702703 # 3. Insert: scatter column i of buffer_dst into the destination, applying the
703704 # actual index permutation p in the same tensoradd! call.
@@ -713,3 +714,12 @@ function add_transform_kernel!(
713714 TO. allocator_reset! (allocator, cp)
714715 return nothing
715716end
717+
718+ """
719+ adapt_transformer(U::AbstractMatrix, ::Type{A})
720+
721+ Return a version of the basis transformation `U` that is compatible for storage type `A`.
722+ Default is a no-op.
723+ Backends (e.g. CUDA, AMDGPU) should overload this for their vector types to ensure the recoupling matrix `U` is on the correct device.
724+ """
725+ adapt_transformer (U:: AbstractMatrix , :: Type{A} ) where {A} = U
0 commit comments