Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "BlockTensorKit"
uuid = "5f87ffc2-9cf1-4a46-8172-465d160bd8cd"
version = "0.3.5"
version = "0.3.6"
authors = ["Lukas Devos <ldevos98@gmail.com> and contributors"]

[deps]
Expand All @@ -25,7 +25,7 @@ MatrixAlgebraKit = "0.6"
Random = "1"
SafeTestsets = "0.1"
Strided = "2"
TensorKit = "0.16"
TensorKit = "0.16.1"
TensorOperations = "5"
Test = "1"
TestExtras = "0.2, 0.3"
Expand Down
17 changes: 17 additions & 0 deletions src/linalg/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,23 @@ function LinearAlgebra.mul!(C::BlockTensorMap, α::Number, A::BlockTensorMap)
return C
end

for TA in (:AbstractBlockTensorMap, :(TK.DiagonalTensorMap), :(TK.AdjointTensorMap), :TensorMap),
TB in (:AbstractBlockTensorMap, :(TK.DiagonalTensorMap), :(TK.AdjointTensorMap), :TensorMap)
(TA === :AbstractBlockTensorMap || TB === :AbstractBlockTensorMap) || continue
@eval function TK.compose_dest(A::$TA, B::$TB)
S = TK.check_spacetype(A, B)
TC = TO.promote_contract(scalartype(A), scalartype(B), One)
M = TK.promote_storagetype(TK.similarstoragetype(A, TC), TK.similarstoragetype(B, TC))
TTC = if issparse(A) && issparse(B)
sparseblocktensormaptype(S, numout(A), numin(B), M)
else
blocktensormaptype(S, numout(A), numin(B), M)
end
structure = codomain(A) ← domain(B)
return TO.tensoralloc(TTC, structure, Val(false))
end
end

# This is a generic implementation of `mul!` for BlockTensors that is used to make it easier
# to work with abstract element types, that might not support in-place operations.
# For now, the implementation might not be hyper-optimized, but the assumption is that we
Expand Down
185 changes: 21 additions & 164 deletions src/tensors/tensoroperations.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,14 @@
@noinline function _check_spacetype(
::Type{S₁}, ::Type{S₂}
) where {S₁ <: ElementarySpace, S₂ <: ElementarySpace}
S₁ === S₂ ||
S₁ === SumSpace{S₂} ||
SumSpace{S₁} === S₂ ||
throw(SpaceMismatch(lazy"incompatible spacetypes: $S₁ and $S₂"))
return nothing
end

# TensorOperations
# ----------------
function TO.tensoradd_type(
TC, A::AbstractBlockTensorMap, ::Index2Tuple{N₁, N₂}, ::Bool
) where {N₁, N₂}
TA = eltype(A)
I = sectortype(A)
Tnew = sectorscalartype(I) <: Real ? TC : complex(TC)
if TA isa Union
M = Union{TK.similarstoragetype(TA.a, Tnew), TK.similarstoragetype(TA.b, Tnew)}
else
M = TK.similarstoragetype(TA, Tnew)
end
S = spacetype(A)
M = TK.similarstoragetype(A, TK.promote_permute(TC, sectortype(S)))
return if issparse(A)
sparseblocktensormaptype(spacetype(A), N₁, N₂, M)
sparseblocktensormaptype(S, N₁, N₂, M)
else
blocktensormaptype(spacetype(A), N₁, N₂, M)
blocktensormaptype(S, N₁, N₂, M)
end
end
function TO.tensoradd_type(TC, A::AdjointBlockTensorMap, pA::Index2Tuple, conjA::Bool)
Expand All @@ -33,58 +17,22 @@ end

# tensoralloc_contract
# --------------------
function TO.tensorcontract_type(
TC,
A::AbstractBlockTensorMap, ::Index2Tuple, ::Bool,
B::AbstractBlockTensorMap, ::Index2Tuple, ::Bool,
::Index2Tuple{N₁, N₂},
) where {N₁, N₂}
_check_spacetype(spacetype(A), spacetype(B))

I = sectortype(A)
Tnew = sectorscalartype(I) <: Real ? TC : complex(TC)
M = promote_storagetype(Tnew, eltype(A), eltype(B))

return if issparse(A) && issparse(B)
sparseblocktensormaptype(spacetype(A), N₁, N₂, M)
else
blocktensormaptype(spacetype(A), N₁, N₂, M)
end
end
function TO.tensorcontract_type(
TC,
A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool,
B::AbstractBlockTensorMap, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple{N₁, N₂},
) where {N₁, N₂}
_check_spacetype(spacetype(A), spacetype(B))

I = sectortype(A)
Tnew = sectorscalartype(I) <: Real ? TC : complex(TC)
M = promote_storagetype(Tnew, typeof(A), eltype(B))

return if issparse(A) && issparse(B)
sparseblocktensormaptype(spacetype(A), N₁, N₂, M)
else
blocktensormaptype(spacetype(A), N₁, N₂, M)
end
end
function TO.tensorcontract_type(
TC,
A::AbstractBlockTensorMap, ::Index2Tuple, ::Bool,
B::AbstractTensorMap, ::Index2Tuple, ::Bool,
::Index2Tuple{N₁, N₂},
) where {N₁, N₂}
_check_spacetype(spacetype(A), spacetype(B))

I = sectortype(A)
Tnew = sectorscalartype(I) <: Real ? TC : complex(TC)
M = promote_storagetype(Tnew, eltype(A), typeof(B))

return if issparse(A) && issparse(B)
sparseblocktensormaptype(spacetype(A), N₁, N₂, M)
else
blocktensormaptype(spacetype(A), N₁, N₂, M)
for TTA in (:AbstractTensorMap, :AbstractBlockTensorMap), TTB in (:AbstractTensorMap, :AbstractBlockTensorMap)
TTA == TTB == :AbstractTensorMap && continue
@eval function TO.tensorcontract_type(
TC,
A::$TTA, ::Index2Tuple, ::Bool,
B::$TTB, ::Index2Tuple, ::Bool,
::Index2Tuple{N₁, N₂},
) where {N₁, N₂}
S = TK.check_spacetype(A, B)
TC′ = TK.promote_permute(TC, sectortype(S))
M = TK.promote_storagetype(TK.similarstoragetype(A, TC′), TK.similarstoragetype(B, TC′))
return if issparse(A) && issparse(B)
sparseblocktensormaptype(S, N₁, N₂, M)
else
blocktensormaptype(S, N₁, N₂, M)
end
end
end

Expand Down Expand Up @@ -117,23 +65,6 @@ function TO.tensoralloc_contract(
end
end

function promote_storagetype(::Type{T}, ::Type{T₁}, ::Type{T₂}) where {T, T₁, T₂}
if T₁ isa Union
M₁ = Union{TK.similarstoragetype(T₁.a, T), TK.similarstoragetype(T₁.b, T)}
else
M₁ = TK.similarstoragetype(T₁, T)
end
if T₂ isa Union
M₂ = Union{TK.similarstoragetype(T₂.a, T), TK.similarstoragetype(T₂.b, T)}
else
M₂ = TK.similarstoragetype(T₂, T)
end
return Union{M₁, M₂}
end

# EVIL HACK!!!
TK.storagetype(::Type{AbstractTensorMap{TT, S, N₁, N₂}}) where {TT, S, N₁, N₂} = Vector{TT}

function promote_blocktype(::Type{TT}, ::Type{A₁}, ::Type{A₂}) where {TT, A₁, A₂}
N = similarblocktype(A₁, TT)
@assert N === similarblocktype(A₂, TT) "incompatible block types"
Expand All @@ -154,80 +85,6 @@ function TO.tensoralloc(
return C
end

# unfortunate overlaod until TK fix
function TK.blas_contract!(
C::AbstractBlockTensorMap,
A::AbstractTensorMap, pA::Index2Tuple,
B::AbstractTensorMap, pB::Index2Tuple,
pAB::Index2Tuple, α, β,
backend, allocator
)
bstyle = BraidingStyle(sectortype(C))
bstyle isa SymmetricBraiding ||
throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead"))
TC = scalartype(C)

# check which tensors have to be permuted/copied
copyA = !(TO.isblascontractable(A, pA) && eltype(A) === TC)
copyB = !(TO.isblascontractable(B, pB) && eltype(B) === TC)

if bstyle isa Fermionic && any(isdual ∘ Base.Fix1(space, B), pB[1])
# twist smallest object if neither or both already have to be permuted
# otherwise twist the one that already is copied
if !(copyA ⊻ copyB)
twistA = dim(A) < dim(B)
else
twistA = copyA
end
twistB = !twistA
copyA |= twistA
copyB |= twistB
else
twistA = false
twistB = false
end

# Bring A in the correct form for BLAS contraction
if copyA
Anew = TO.tensoralloc_add(TC, A, pA, false, Val(true), allocator)
Anew = TO.tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator)
twistA && twist!(Anew, filter(!isdual ∘ Base.Fix1(space, Anew), domainind(Anew)))
else
Anew = permute(A, pA)
end
pAnew = (codomainind(Anew), domainind(Anew))

# Bring B in the correct form for BLAS contraction
if copyB
Bnew = TO.tensoralloc_add(TC, B, pB, false, Val(true), allocator)
Bnew = TO.tensoradd!(Bnew, B, pB, false, One(), Zero(), backend, allocator)
twistB && twist!(Bnew, filter(isdual ∘ Base.Fix1(space, Bnew), codomainind(Bnew)))
else
Bnew = permute(B, pB)
end
pBnew = (codomainind(Bnew), domainind(Bnew))

# Bring C in the correct form for BLAS contraction
ipAB = TO.oindABinC(pAB, pAnew, pBnew)
copyC = !TO.isblasdestination(C, ipAB)

if copyC
Cnew = TO.tensoralloc_add(TC, C, ipAB, false, Val(true), allocator)
mul!(Cnew, Anew, Bnew)
TO.tensoradd!(C, Cnew, pAB, false, α, β, backend, allocator)
TO.tensorfree!(Cnew, allocator)
else
Cnew = permute(C, ipAB)
mul!(Cnew, Anew, Bnew, α, β)
end

copyA && TO.tensorfree!(Anew, allocator)
copyB && TO.tensorfree!(Bnew, allocator)

return C
end


# tensorfree!
# -----------
function TO.tensorfree!(t::BlockTensorMap, allocator = TO.DefaultAllocator())
Expand All @@ -248,7 +105,7 @@ function TK.trace_permute!(
backend::AbstractBackend = TO.DefaultBackend(),
)
# some input checks
_check_spacetype(spacetype(tdst), spacetype(tsrc))
TK.check_spacetype(tdst, tsrc)
if !(BraidingStyle(sectortype(tdst)) isa SymmetricBraiding)
throw(
SectorMismatch(
Expand Down
22 changes: 22 additions & 0 deletions src/tensors/vectorinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,28 @@ end

# add
# ---
function VI.add(ty::AbstractBlockTensorMap, tx::AbstractBlockTensorMap, α::Number, β::Number)
S = TK.check_spacetype(ty, tx)
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty)) ≠ $(space(tx))"))

# result type defaults to TensorMap if the types don't match to avoid assymmetric
# implementation via zerovector(ty, T) vs zerovector(tx, T)
# This would give issues for example with DiagonalTensorMap + TensorMap
T = VectorInterface.promote_add(ty, tx, α, β)
tdst = if typeof(ty) === typeof(tx)
zerovector(ty, T)
else
M = TK.promote_storagetype(TK.similarstoragetype(ty, T), TK.similarstoragetype(tx, T))
if issparse(ty) && issparse(tx)
sparseblocktensormaptype(S, numout(ty), numin(ty), M)(undef, space(ty))
else
blocktensormaptype(S, numout(ty), numin(ty), M)(undef, space(ty))
end
end

return add!(scale!(tdst, ty, β), tx, α)
end

function VI.add!(ty::BlockTensorMap, tx::BlockTensorMap, α::Number, β::Number)
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty)) ≠ $(space(tx))"))
add!(parent(ty), parent(tx), α, β)
Expand Down