Skip to content

Commit e5d63ac

Browse files
committed
More updates for GPU friendliness
1 parent 6164c03 commit e5d63ac

4 files changed

Lines changed: 27 additions & 13 deletions

File tree

src/tensors/abstractblocktensor/conversion.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Conversion
22
# ----------
3-
function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap)
3+
4+
# use subtype of TensorMap here to support CuTensorMap,
5+
# ROCTensorMap, etc.
6+
function Base.convert(::Type{<:TensorMap}, t::AbstractBlockTensorMap)
47
S = spacetype(t)
58
N₁, N₂ = numout(t), numin(t)
69
cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces))
@@ -22,24 +25,15 @@ function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap)
2225
indices = getindex.(blockax, Block.(Tuple(k)))
2326
arr_slice = arr[indices...]
2427
# need to check for empty since fusion tree pair might not be present
25-
isempty(arr_slice) || copy!(arr_slice, v[f₁, f₂])
28+
if !isempty(arr_slice)
29+
arr[indices...] .= v[f₁, f₂]
30+
end
2631
end
2732
end
2833

2934
return tdst
3035
end
3136

32-
function Base.convert(::Type{T}, t::AbstractBlockTensorMap) where {T <: TensorMap}
33-
tdst = convert(TensorMap, t)
34-
return convert(T, tdst)
35-
end
36-
37-
# disambiguation
38-
function Base.convert(::Type{TensorMap{T, S, N₁, N₂, A}}, t::AB) where {T, S, N₁, N₂, A, AB <: AbstractBlockTensorMap}
39-
tdst = convert(TensorMap, t)
40-
return convert(T, tdst)
41-
end
42-
4337
function Base.convert(::Type{TT}, t::AbstractTensorMap) where {TT <: AbstractBlockTensorMap}
4438
t isa TT && return t
4539
if t isa AbstractBlockTensorMap

src/tensors/blocktensor.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ struct BlockTensorMap{TT <: AbstractTensorMap, E, S, N₁, N₂, N} <:
2626
end
2727
end
2828

29+
# seems necessary to dispatch correctly onto the storage type of TT
30+
# AbstractBlockTensorMap doesn't have this TT type parameter
31+
# The storage type of any BlockTensorMap type is the one reported by its
# type parameter TT.
function TensorKit.storagetype(
        ::Type{<:BlockTensorMap{TT}}
    ) where {TT <: AbstractTensorMap}
    return storagetype(TT)
end
32+
2933
function BlockTensorMap{TT, E, S, N₁, N₂, N}(
3034
::UndefInitializer, space::TensorMapSumSpace{S, N₁, N₂}
3135
) where {TT, E, S, N₁, N₂, N}

src/tensors/sparseblocktensor.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ function sparseblocktensormaptype(
6363
return SparseBlockTensorMap{TT}
6464
end
6565

66+
# seems necessary to dispatch correctly onto the storage type of TT
67+
# AbstractBlockTensorMap doesn't have this TT type parameter
68+
# The storage type of any SparseBlockTensorMap type is the one reported by its
# type parameter TT.
function TensorKit.storagetype(
        ::Type{<:SparseBlockTensorMap{TT}}
    ) where {TT <: AbstractTensorMap}
    return storagetype(TT)
end
69+
6670
# Constructors
6771
# ------------
6872
function SparseBlockTensorMap{TT}(

src/tensors/tensoroperations.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@ function TO.tensoradd_type(TC, A::AdjointBlockTensorMap, pA::Index2Tuple, conjA:
1515
return TO.tensoradd_type(TC, A', adjointtensorindices(A, pA), !conjA)
1616
end
1717

18+
# copy blocks back to CPU/collect them into an array
19+
# seems necessary for GPU-backed BlockTensorMaps but
20+
# maybe not the most efficient approach?
21+
# Extract the scalar held by a block tensor map with no codomain and no domain
# indices (N₁ = N₂ = 0).
#
# Walks the pairs produced by `TK.blocks(t)`, reads the `blocks` field of each
# pair's second element, and returns the single entry of the first stored
# sub-block that is not zero. If every stored sub-block is zero, or nothing is
# stored, the zero of the tensor's scalar type is returned.
#
# The value itself is returned, not an entry of a Boolean "is nonzero" mask:
# indexing into such a mask would hand callers `true`/`false` instead of the
# scalar.
function TO.tensorscalar(t::AbstractBlockTensorMap{T, S, 0, 0}) where {T, S}
    for sector_block in TK.blocks(t)
        # gather the stored sub-blocks of this block into a plain array
        stored = collect(getfield(last(sector_block), :blocks))
        idx = findfirst(!iszero, stored)
        idx === nothing && continue
        # NOTE(review): the inner `collect` is meant to bring a device-backed
        # sub-block to host memory before `only` reads its single entry —
        # confirm for the GPU array types in use.
        return only(collect(stored[idx]))
    end
    return zero(TK.scalartype(t))
end
29+
1830
# tensoralloc_contract
1931
# --------------------
2032
for TTA in (:AbstractTensorMap, :AbstractBlockTensorMap), TTB in (:AbstractTensorMap, :AbstractBlockTensorMap)

0 commit comments

Comments
 (0)