Skip to content

Commit 6fe776e

Browse files
authored
optimize converter to TensorMap (#40)
* optimize converter to `TensorMap` * fix for sparse * bump v0.3.3 [skip ci]
1 parent d2ef909 commit 6fe776e

2 files changed

Lines changed: 25 additions & 17 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "BlockTensorKit"
22
uuid = "5f87ffc2-9cf1-4a46-8172-465d160bd8cd"
3-
version = "0.3.2"
3+
version = "0.3.3"
44
authors = ["Lukas Devos <ldevos98@gmail.com> and contributors"]
55

66
[deps]

src/tensors/abstractblocktensor/conversion.jl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,37 @@
11
# Conversion
22
# ----------
3-
function Base.convert(::Type{T}, t::AbstractBlockTensorMap) where {T <: TensorMap}
4-
cod = ProductSpace{spacetype(t), numout(t)}(oplus.(codomain(t).spaces))
5-
dom = ProductSpace{spacetype(t), numin(t)}(oplus.(domain(t).spaces))
6-
3+
function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap)
4+
S = spacetype(t)
5+
N₁, N₂ = numout(t), numin(t)
6+
cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces))
7+
dom = ProductSpace{S, N₂}(oplus.(domain(t).spaces))
78
tdst = similar(t, cod dom)
8-
for (f₁, f₂) in fusiontrees(tdst)
9-
tdst[f₁, f₂] .= t[f₁, f₂]
10-
end
119

12-
return convert(T, tdst)
13-
end
14-
# disambiguate
15-
function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap)
16-
cod = ProductSpace{spacetype(t), numout(t)}(oplus.(codomain(t).spaces))
17-
dom = ProductSpace{spacetype(t), numin(t)}(oplus.(domain(t).spaces))
10+
issparse(t) && zerovector!(tdst)
1811

19-
tdst = similar(t, cod dom)
20-
for (f₁, f₂) in fusiontrees(tdst)
21-
copyto!(tdst[f₁, f₂], t[f₁, f₂])
12+
for ((f₁, f₂), arr) in subblocks(tdst)
13+
blockax = ntuple(N₁ + N₂) do i
14+
return if i <= N₁
15+
blockedrange(map(Base.Fix2(dim, f₁.uncoupled[i]), space(t, i)))
16+
else
17+
blockedrange(map(Base.Fix2(dim, f₂.uncoupled[i - N₁]), space(t, i)'))
18+
end
19+
end
20+
21+
for (k, v) in nonzero_pairs(t)
22+
indices = getindex.(blockax, Block.(Tuple(k)))
23+
copy!(arr[indices...], v[f₁, f₂])
24+
end
2225
end
2326

2427
return tdst
2528
end
2629

30+
function Base.convert(::Type{T}, t::AbstractBlockTensorMap) where {T <: TensorMap}
31+
tdst = convert(TensorMap, t)
32+
return convert(T, tdst)
33+
end
34+
2735
function Base.convert(::Type{TT}, t::AbstractTensorMap) where {TT <: AbstractBlockTensorMap}
2836
t isa TT && return t
2937
if t isa AbstractBlockTensorMap

0 commit comments

Comments
 (0)