Skip to content

Commit 87013c1

Browse files
committed
add hook for adapt_transformer
1 parent 2841304 commit 87013c1

4 files changed

Lines changed: 26 additions & 2 deletions

File tree

ext/TensorKitAMDGPUExt/roctensormap.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,12 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
162162
return tf
163163
end
164164
end
165+
166+
# AMDGPU specialization of `TensorKit.adapt_transformer`.
#
# Rebuilds the transformer from `t.data`, passing the first component of every
# entry through `Adapt.adapt(ROCArray, ·)` and forwarding the remaining two
# components unchanged. `data` is only used for dispatch (it selects this method
# when the target storage is a `ROCVector`); its contents are never read.
function TensorKit.adapt_transformer(
        t::TensorKit.GenericTreeTransformer, data::ROCVector
    )
    # Per-entry conversion: only the leading element of the triple is adapted.
    to_device((U, structs_dst, structs_src)) =
        (AMDGPU.Adapt.adapt(ROCArray, U), structs_dst, structs_src)
    return TensorKit.GenericTreeTransformer(map(to_device, t.data))
end

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,11 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
169169
end
170170
end
171171

172-
function TensorKit._add_transform_multi!(tdst::CuTensorMap, tsrc, p, (U, structs_dst, structs_src)::Tuple{<:Array, TD, TS}, buffers, alpha, beta, backend...) where {TD, TS}
173-
return TensorKit._add_transform_multi!(tdst, tsrc, p, (CUDA.Adapt.adapt(CuArray, U), structs_dst, structs_src), buffers, alpha, beta, backend...)
172+
# CUDA specialization of `TensorKit.adapt_transformer`.
#
# Rebuilds the transformer from `t.data`, passing the first component of every
# entry through `Adapt.adapt(CuArray, ·)` and forwarding the remaining two
# components unchanged. `data` is only used for dispatch (it selects this method
# when the target storage is a `CuVector`); its contents are never read.
function TensorKit.adapt_transformer(
        t::TensorKit.GenericTreeTransformer, data::CuVector
    )
    # Per-entry conversion: only the leading element of the triple is adapted.
    to_device((U, structs_dst, structs_src)) =
        (CUDA.Adapt.adapt(CuArray, U), structs_dst, structs_src)
    return TensorKit.GenericTreeTransformer(map(to_device, t.data))
end

src/tensors/indexmanipulations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,7 @@ function add_transform_kernel!(
661661
data_dst::DenseVector, data_src::DenseVector, p, transformer::GenericTreeTransformer,
662662
α, β, backend, allocator, scheduler
663663
)
664+
transformer = adapt_transformer(transformer, data_dst)
664665
# Each entry covers one fusion block:
665666
# U — recoupling matrix (rows = dst trees, cols = src trees)
666667
# sz_{dst,src} — array shape of each block (same for all trees in the block)

src/tensors/treetransformers.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,12 @@ end
203203
# Weight of a single transformer entry: the number of elements of its matrix
# multiplied by the product of the first element of its destination structures.
function _transformer_weight(entry::GenericTransformerData)
    mat, structs_dst, _ = entry
    n_elements = length(mat)
    leading_prod = prod(structs_dst[1])
    return n_elements * leading_prod
end
206+
207+
"""
    adapt_transformer(transformer::TreeTransformer, data::AbstractVector)

Return a transformer whose internal arrays can be used together with `data`.

This generic fallback returns `transformer` unchanged. Array backends (e.g. CUDA, AMDGPU)
are expected to add methods for their own vector types so that the recoupling matrix `U`
stored in a `GenericTreeTransformer` ends up on the matching device.
"""
function adapt_transformer(t::TreeTransformer, ::AbstractVector)
    return t
end

0 commit comments

Comments
 (0)