Skip to content

Commit c4408b3

Browse files
committed
update adapt_transformer
1 parent 34d74c2 commit c4408b3

4 files changed

Lines changed: 15 additions & 28 deletions

File tree

ext/TensorKitAMDGPUExt/roctensormap.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,4 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
163163
end
164164
end
165165

166-
function TensorKit.adapt_transformer(
167-
t::TensorKit.GenericTreeTransformer, data::ROCVector
168-
)
169-
new_data = map(t.data) do (U, structs_dst, structs_src)
170-
return AMDGPU.Adapt.adapt(ROCArray, U), structs_dst, structs_src
171-
end
172-
return TensorKit.GenericTreeTransformer(new_data)
173-
end
166+
TensorKit.adapt_transformer(U::AbstractMatrix, ::Type{A}) where {A <: ROCVector} = AMDGPU.Adapt.adapt(ROCArray, U)

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,4 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
169169
end
170170
end
171171

172-
function TensorKit.adapt_transformer(
173-
t::TensorKit.GenericTreeTransformer, data::CuVector
174-
)
175-
new_data = map(t.data) do (U, structs_dst, structs_src)
176-
return CUDA.Adapt.adapt(CuArray, U), structs_dst, structs_src
177-
end
178-
return TensorKit.GenericTreeTransformer(new_data)
179-
end
172+
TensorKit.adapt_transformer(U::AbstractMatrix, ::Type{A}) where {A <: CuVector} = CUDA.Adapt.adapt(CuArray, U)

src/tensors/indexmanipulations.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,8 @@ function add_transform_kernel!(
624624
ptriv, false, One(), Zero(), backend, allocator
625625
)
626626
end
627-
mul!(buffer_dst, buffer_src, transpose(StridedView(U)))
627+
U′ = adapt_transformer(U, storagetype(tdst))
628+
mul!(buffer_dst, buffer_src, transpose(StridedView(U′)))
628629
@inbounds for (i, (f₃, f₄)) in enumerate(fusiontrees(dst))
629630
TO.tensoradd!(
630631
tdst[f₃, f₄], sreshape(buffer_dst[:, i], sz_src),
@@ -661,7 +662,6 @@ function add_transform_kernel!(
661662
data_dst::DenseVector, data_src::DenseVector, p, transformer::GenericTreeTransformer,
662663
α, β, backend, allocator, scheduler
663664
)
664-
transformer = adapt_transformer(transformer, data_dst)
665665
# Each entry covers one fusion block:
666666
# U — recoupling matrix (rows = dst trees, cols = src trees)
667667
# sz_{dst,src} — array shape of each block (same for all trees in the block)
@@ -697,7 +697,8 @@ function add_transform_kernel!(
697697

698698
# 2. Recoupling: buffer_dst = buffer_src * U^T (each output tree is a linear
699699
# combination of input trees weighted by the recoupling coefficients).
700-
mul!(buffer_dst, buffer_src, transpose(StridedView(U)))
700+
U′ = adapt_transformer(U, typeof(data_dst))
701+
mul!(buffer_dst, buffer_src, transpose(StridedView(U′)))
701702

702703
# 3. Insert: scatter column i of buffer_dst into the destination, applying the
703704
# actual index permutation p in the same tensoradd! call.
@@ -713,3 +714,12 @@ function add_transform_kernel!(
713714
TO.allocator_reset!(allocator, cp)
714715
return nothing
715716
end
717+
718+
"""
719+
adapt_transformer(U::AbstractMatrix, ::Type{A})
720+
721+
Return a version of the basis transformation `U` that is compatible for storage type `A`.
722+
Default is a no-op.
723+
Backends (e.g. CUDA, AMDGPU) should overload this for their vector types to ensure the recoupling matrix `U` is on the correct device.
724+
"""
725+
adapt_transformer(U::AbstractMatrix, ::Type{A}) where {A} = U

src/tensors/treetransformers.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,3 @@ end
203203
function _transformer_weight((mat, structs_dst, structs_src)::GenericTransformerData)
204204
return length(mat) * prod(structs_dst[1])
205205
end
206-
207-
"""
208-
adapt_transformer(transformer::TreeTransformer, data::AbstractVector)
209-
210-
Return a version of `transformer` whose internal arrays are compatible with `data`.
211-
Default is a no-op. Backends (e.g. CUDA, AMDGPU) should overload this for their vector types
212-
to ensure the recoupling matrix `U` inside `GenericTreeTransformer` is on the correct device.
213-
"""
214-
adapt_transformer(t::TreeTransformer, ::AbstractVector) = t

0 commit comments

Comments
 (0)