Skip to content

Commit 6f44092

Browse files
committed
More updates for GPU friendliness
1 parent 6164c03 commit 6f44092

4 files changed

Lines changed: 26 additions & 12 deletions

File tree

src/tensors/abstractblocktensor/conversion.jl

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Conversion
22
# ----------
3+
34
function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap)
45
S = spacetype(t)
56
N₁, N₂ = numout(t), numin(t)
@@ -22,23 +23,17 @@ function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap)
2223
indices = getindex.(blockax, Block.(Tuple(k)))
2324
arr_slice = arr[indices...]
2425
# need to check for empty since fusion tree pair might not be present
25-
isempty(arr_slice) || copy!(arr_slice, v[f₁, f₂])
26+
if !isempty(arr_slice)
27+
arr[indices...] .= v[f₁, f₂]
28+
end
2629
end
2730
end
2831

2932
return tdst
3033
end
31-
32-
function Base.convert(::Type{T}, t::AbstractBlockTensorMap) where {T <: TensorMap}
33-
tdst = convert(TensorMap, t)
34-
return convert(T, tdst)
35-
end
36-
37-
# disambiguation
38-
function Base.convert(::Type{TensorMap{T, S, N₁, N₂, A}}, t::AB) where {T, S, N₁, N₂, A, AB <: AbstractBlockTensorMap}
39-
tdst = convert(TensorMap, t)
40-
return convert(T, tdst)
41-
end
34+
# use subtype of TensorMap here to support CuTensorMap
35+
# ROCTensorMap, etc.
36+
Base.convert(::Type{<:TensorMap}, t::AbstractBlockTensorMap) = convert(TensorMap, t)
4237

4338
function Base.convert(::Type{TT}, t::AbstractTensorMap) where {TT <: AbstractBlockTensorMap}
4439
t isa TT && return t

src/tensors/blocktensor.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ struct BlockTensorMap{TT <: AbstractTensorMap, E, S, N₁, N₂, N} <:
2626
end
2727
end
2828

29+
# seems necessary to dispatch correctly onto the storage type of TT
30+
# AbstractBlockTensorMap doesn't have this TT field
31+
TensorKit.storagetype(::Type{<:BlockTensorMap{TT}}) where {TT <: AbstractTensorMap} = storagetype(TT)
32+
2933
function BlockTensorMap{TT, E, S, N₁, N₂, N}(
3034
::UndefInitializer, space::TensorMapSumSpace{S, N₁, N₂}
3135
) where {TT, E, S, N₁, N₂, N}

src/tensors/sparseblocktensor.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ function sparseblocktensormaptype(
6363
return SparseBlockTensorMap{TT}
6464
end
6565

66+
# seems necessary to dispatch correctly onto the storage type of TT
67+
# AbstractBlockTensorMap doesn't have this TT field
68+
TensorKit.storagetype(::Type{<:SparseBlockTensorMap{TT}}) where {TT <: AbstractTensorMap} = storagetype(TT)
69+
6670
# Constructors
6771
# ------------
6872
function SparseBlockTensorMap{TT}(

src/tensors/tensoroperations.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ function TO.tensoradd_type(TC, A::AdjointBlockTensorMap, pA::Index2Tuple, conjA:
1515
return TO.tensoradd_type(TC, A', adjointtensorindices(A, pA), !conjA)
1616
end
1717

18+
# copy blocks back to CPU/collect them into an array
19+
# seems necessary for GPU-backed BlockTensorMaps but
20+
# maybe not the most efficient approach?
21+
function TO.tensorscalar(t::AbstractBlockTensorMap{T, S, 0, 0}) where {T, S}
22+
Bs = TK.blocks(t)
23+
B_ends = collect.(map(b -> collect.(getfield(b, :blocks)), map(last, Bs)))
24+
inds = findall(!iszero last, B_ends)
25+
isempty(inds) && return zero(TKscalartype(t))
26+
return only(last(B_ends[only(inds)]))
27+
end
28+
1829
# tensoralloc_contract
1930
# --------------------
2031
for TTA in (:AbstractTensorMap, :AbstractBlockTensorMap), TTB in (:AbstractTensorMap, :AbstractBlockTensorMap)

0 commit comments

Comments
 (0)