Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
Expand Down Expand Up @@ -37,6 +38,7 @@ LinearAlgebra = "1"
MatrixAlgebraKit = "0.5.0"
OhMyThreads = "0.8.0"
PackageExtensionCompat = "1"
Printf = "1"
Random = "1"
SafeTestsets = "0.1"
ScopedValues = "1.3.0"
Expand Down
6 changes: 3 additions & 3 deletions docs/src/lib/tensors.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ blocks

To access the data associated with a specific fusion tree pair, you can use:
```@docs
Base.getindex(::TensorMap{T,S,N₁,N₂}, ::FusionTree{I,N₁}, ::FusionTree{I,N₂}) where {T,S,N₁,N₂,I<:Sector}
Base.setindex!(::TensorMap{T,S,N₁,N₂}, ::Any, ::FusionTree{I,N₁}, ::FusionTree{I,N₂}) where {T,S,N₁,N₂,I<:Sector}
Base.getindex(::AbstractTensorMap, ::FusionTree, ::FusionTree)
Base.setindex!(::AbstractTensorMap, ::Any, ::FusionTree, ::FusionTree)
```

For a tensor `t` with `FusionType(sectortype(t)) isa UniqueFusion`, fusion trees are
completely determined by the outcoming sectors, and the data can be accessed in a more
straightforward way:
```@docs
Base.getindex(::TensorMap, ::Tuple{I,Vararg{I}}) where {I<:Sector}
Base.getindex(::AbstractTensorMap, ::Tuple{I,Vararg{I}}) where {I<:Sector}
```

For tensor `t` with `sectortype(t) == Trivial`, the data can be accessed and manipulated
Expand Down
3 changes: 2 additions & 1 deletion src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ export ℤ₂Space, ℤ₃Space, ℤ₄Space, U₁Space, CU₁Space, SU₂Space
# Export tensor map methods
export domain, codomain, numind, numout, numin, domainind, codomainind, allind
export spacetype, storagetype, scalartype, tensormaptype
export blocksectors, blockdim, block, blocks
export blocksectors, blockdim, block, blocks, subblocks, subblock

# random methods for constructor
export randisometry, randisometry!, rand, rand!, randn, randn!
Expand Down Expand Up @@ -127,6 +127,7 @@ using Base: @boundscheck, @propagate_inbounds, @constprop,
tuple_type_head, tuple_type_tail, tuple_type_cons,
SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype
using Base.Iterators: product, filter
using Printf: @sprintf

using LinearAlgebra: LinearAlgebra, BlasFloat
using LinearAlgebra: norm, dot, normalize, normalize!, tr,
Expand Down
60 changes: 46 additions & 14 deletions src/spaces/gradedspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,21 +197,53 @@ function supremum(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I <: Sector
)
end

function Base.show(io::IO, V::GradedSpace{I}) where {I <: Sector}
print(io, type_repr(typeof(V)), "(")
separator = ""
comma = ", "
io2 = IOContext(io, :typeinfo => I)
for c in sectors(V)
if isdual(V)
print(io2, separator, dual(c), "=>", dim(V, c))
else
print(io2, separator, c, "=>", dim(V, c))
end
separator = comma
Base.summary(io::IO, V::GradedSpace) = print(io, type_repr(typeof(V)))

function Base.show(io::IO, V::GradedSpace)
opn = (get(io, :typeinfo, Any)::DataType == typeof(V) ? "" : type_repr(typeof(V)))
opn *= "("
if isdual(V)
cls = ")'"
V = dual(V)
else
cls = ")"
end

v = [c => dim(V, c) for c in sectors(V)]

# logic stolen from Base.show_vector
limited = get(io, :limit, false)::Bool
io = IOContext(io, :typeinfo => eltype(v))

if limited && length(v) > 20
axs1 = axes(v, 1)
f, l = first(axs1), last(axs1)
Base.show_delim_array(io, v, opn, ",", "", false, f, f + 9)
print(io, " … ")
Base.show_delim_array(io, v, "", ",", cls, false, l - 9, l)
else
Base.show_delim_array(io, v, opn, ",", cls, false)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is in principle a private function, though I am not opposed to using it. But it could easily be broken in a next Julia release.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree that in principle we should probably vendor this ourselves. Unfortunately it is actually quite ugly/involved, so I would really prefer to not have to do so, and it might actually require less effort to simply update this function if Julia breaks it...

end
print(io, ")")
V.dual && print(io, "'")
return nothing
end

function Base.show(io::IO, ::MIME"text/plain", V::GradedSpace)
# print small summary, e.g.: Vect[I](…) of dim d
d = dim(V)
print(io, type_repr(typeof(d)), "(…)")
isdual(V) && print(io, "'")
print(io, " of dim ", d)

compact = get(io, :compact, false)::Bool
(iszero(d) || compact) && return nothing

# print detailed sector information - hijack Base.Vector printing
print(io, ":\n")
isdual(V) && (V = dual(V))
print_data = [c => dim(V, c) for c in sectors(V)]
ioc = IOContext(io, :typeinfo => eltype(print_data))
Base.print_matrix(ioc, print_data)

return nothing
end

Expand Down
161 changes: 161 additions & 0 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,12 @@ Return an iterator over all splitting - fusion tree pairs of a tensor.
"""
fusiontrees(t::AbstractTensorMap) = fusionblockstructure(t).fusiontreelist

fusiontreetype(t::AbstractTensorMap) = fusiontreetype(typeof(t))
function fusiontreetype(::Type{T}) where {T <: AbstractTensorMap}
I = sectortype(T)
return Tuple{fusiontreetype(I, numout(T)), fusiontreetype(I, numin(T))}
end

# auxiliary function
@inline function trivial_fusiontree(t::AbstractTensorMap)
sectortype(t) === Trivial ||
Expand Down Expand Up @@ -295,6 +301,126 @@ function blocktype(::Type{T}) where {T <: AbstractTensorMap}
return Core.Compiler.return_type(block, Tuple{T, sectortype(T)})
end

# tensor data: subblock access
# ----------------------------
@doc """
subblocks(t::AbstractTensorMap)

Return an iterator over all subblocks of a tensor, i.e. all fusiontrees and their
corresponding tensor subblocks.

See also [`subblock`](@ref), [`fusiontrees`](@ref), and [`hassubblock`](@ref).
"""
subblocks(t::AbstractTensorMap) = SubblockIterator(t, fusiontrees(t))

const _doc_subblock = """
Return a view into the data of `t` corresponding to the splitting - fusion tree pair
`(f₁, f₂)`. In particular, this is an `AbstractArray{T}` with `T = scalartype(t)`, of size
`(dims(codomain(t), f₁.uncoupled)..., dims(codomain(t), f₂.uncoupled)...)`.

Whenever `FusionStyle(sectortype(t)) isa UniqueFusion` , it is also possible to provide only
the external `sectors`, in which case the fusion tree pair will be constructed automatically.
"""

@doc """
subblock(t::AbstractTensorMap, (f₁, f₂)::Tuple{FusionTree,FusionTree})
subblock(t::AbstractTensorMap, sectors::Tuple{Vararg{Sector}})

$_doc_subblock

In general, new tensor types should provide an implementation of this function for the
fusion tree signature.

See also [`subblocks`](@ref) and [`fusiontrees`](@ref).
""" subblock

Base.@propagate_inbounds function subblock(t::AbstractTensorMap, sectors::Tuple{I, Vararg{I}}) where {I <: Sector}
# input checking
I === sectortype(t) || throw(SectorMismatch("Not a valid sectortype for this tensor."))
FusionStyle(I) isa UniqueFusion ||
throw(SectorMismatch("Indexing with sectors is only possible for unique fusion styles."))
length(sectors) == numind(t) || throw(ArgumentError("invalid number of sectors"))

# convert to fusiontrees
s₁ = TupleTools.getindices(sectors, codomainind(t))
s₂ = map(dual, TupleTools.getindices(sectors, domainind(t)))
c1 = length(s₁) == 0 ? unit(I) : (length(s₁) == 1 ? s₁[1] : first(⊗(s₁...)))
@boundscheck begin
hassector(codomain(t), s₁) && hassector(domain(t), s₂) || throw(BoundsError(t, sectors))
c2 = length(s₂) == 0 ? unit(I) : (length(s₂) == 1 ? s₂[1] : first(⊗(s₂...)))
c2 == c1 || throw(SectorMismatch("Not a valid fusion channel for this tensor"))
end
f₁ = FusionTree(s₁, c1, map(isdual, tuple(codomain(t)...)))
f₂ = FusionTree(s₂, c1, map(isdual, tuple(domain(t)...)))
return @inbounds subblock(t, (f₁, f₂))
end
Base.@propagate_inbounds function subblock(t::AbstractTensorMap, sectors::Tuple)
return subblock(t, map(Base.Fix1(convert, sectortype(t)), sectors))
end
# attempt to provide better error messages
function subblock(t::AbstractTensorMap, (f₁, f₂)::Tuple{FusionTree, FusionTree})
(sectortype(t)) == sectortype(f₁) == sectortype(f₂) ||
throw(SectorMismatch("Not a valid sectortype for this tensor."))
numout(t) == length(f₁) && numin(t) == length(f₂) ||
throw(DimensionMismatch("Invalid number of fusiontree legs for this tensor."))
throw(MethodError(subblock, (t, (f₁, f₂))))
end

@doc """
subblocktype(t)
subblocktype(::Type{T})

Return the type of the tensor subblocks of a tensor.
""" subblocktype

function subblocktype(::Type{T}) where {T <: AbstractTensorMap}
return Core.Compiler.return_type(subblock, Tuple{T, fusiontreetype(T)})
end
subblocktype(t) = subblocktype(typeof(t))
subblocktype(T::Type) = throw(MethodError(subblocktype, (T,)))

# Indexing behavior
# -----------------
# by default getindex returns views!
@doc """
Base.getindex(t::AbstractTensorMap, sectors::Tuple{Vararg{Sector}})
t[sectors]
Base.getindex(t::AbstractTensorMap, f₁::FusionTree, f₂::FusionTree)
t[f₁, f₂]

$_doc_subblock

!!! warning
Contrary to Julia's array types, the default behavior is to return a view into the tensor data.
As a result, modifying the view will modify the data in the tensor.

See also [`subblock`](@ref), [`subblocks`](@ref) and [`fusiontrees`](@ref).
""" Base.getindex(::AbstractTensorMap, ::Tuple{I, Vararg{I}}) where {I <: Sector},
Base.getindex(::AbstractTensorMap, ::FusionTree, ::FusionTree)

@inline Base.getindex(t::AbstractTensorMap, sectors::Tuple{I, Vararg{I}}) where {I <: Sector} =
subblock(t, sectors)
@inline Base.getindex(t::AbstractTensorMap, f₁::FusionTree, f₂::FusionTree) =
subblock(t, (f₁, f₂))

@doc """
Base.setindex!(t::AbstractTensorMap, v, sectors::Tuple{Vararg{Sector}})
t[sectors] = v
Base.setindex!(t::AbstractTensorMap, v, f₁::FusionTree, f₂::FusionTree)
t[f₁, f₂] = v

Copies `v` into the data slice of `t` corresponding to the splitting - fusion tree pair `(f₁, f₂)`.
By default, `v` can be any object that can be copied into the view associated with `t[f₁, f₂]`.

See also [`subblock`](@ref), [`subblocks`](@ref) and [`fusiontrees`](@ref).
""" Base.setindex!(::AbstractTensorMap, ::Any, ::Tuple{I, Vararg{I}}) where {I <: Sector},
Base.setindex!(::AbstractTensorMap, ::Any, ::FusionTree, ::FusionTree)

@inline Base.setindex!(t::AbstractTensorMap, v, sectors::Tuple{I, Vararg{I}}) where {I <: Sector} =
copy!(subblock(t, sectors), v)
@inline Base.setindex!(t::AbstractTensorMap, v, f₁::FusionTree, f₂::FusionTree) =
copy!(subblock(t, (f₁, f₂)), v)

# Derived indexing behavior for tensors with trivial symmetry
#-------------------------------------------------------------
using TensorKit.Strided: SliceIndex
Expand Down Expand Up @@ -499,3 +625,38 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap)
return A
end
end

# Show and friends
# ----------------

function Base.dims2string(V::HomSpace)
str_cod = numout(V) == 0 ? "()" : join(dim.(codomain(V)), '×')
str_dom = numin(V) == 0 ? "()" : join(dim.(domain(V)), '×')
return str_cod * "←" * str_dom
end

function Base.summary(io::IO, t::AbstractTensorMap)
V = space(t)
print(io, Base.dims2string(V), " ")
Base.showarg(io, t, true)
return nothing
end

# Human-readable:
function Base.show(io::IO, ::MIME"text/plain", t::AbstractTensorMap)
# 1) show summary: typically d₁×d₂×… ← d₃×d₄×… $(typeof(t)):
summary(io, t)
println(io, ":")

# 2) show spaces
# println(io, " space(t):")
println(io, " codomain: ", codomain(t))
println(io, " domain: ", domain(t))

# 3) [optional]: show data
get(io, :compact, true) && return nothing
ioc = IOContext(io, :typeinfo => sectortype(t))
println(io, "\n\n blocks: ")
show_blocks(io, MIME"text/plain"(), blocks(t))
return nothing
end
42 changes: 7 additions & 35 deletions src/tensors/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,45 +42,17 @@ function Base.getindex(iter::BlockIterator{<:AdjointTensorMap}, c::Sector)
return adjoint(Base.getindex(iter.structure, c))
end

function Base.getindex(
t::AdjointTensorMap{T, S, N₁, N₂}, f₁::FusionTree{I, N₁}, f₂::FusionTree{I, N₂}
) where {T, S, N₁, N₂, I}
Base.@propagate_inbounds function subblock(t::AdjointTensorMap, (f₁, f₂)::Tuple{FusionTree, FusionTree})
tp = parent(t)
subblock = getindex(tp, f₂, f₁)
return permutedims(conj(subblock), (domainind(tp)..., codomainind(tp)...))
end
function Base.setindex!(
t::AdjointTensorMap{T, S, N₁, N₂}, v, f₁::FusionTree{I, N₁}, f₂::FusionTree{I, N₂}
) where {T, S, N₁, N₂, I}
return copy!(getindex(t, f₁, f₂), v)
data = subblock(tp, (f₂, f₁))
return permutedims(conj(data), (domainind(tp)..., codomainind(tp)...))
end

# Show
#------
function Base.summary(io::IO, t::AdjointTensorMap)
return print(io, "AdjointTensorMap(", codomain(t), " ← ", domain(t), ")")
end
function Base.show(io::IO, t::AdjointTensorMap)
if get(io, :compact, false)
print(io, "AdjointTensorMap(", codomain(t), " ← ", domain(t), ")")
return
end
println(io, "AdjointTensorMap(", codomain(t), " ← ", domain(t), "):")
if sectortype(t) === Trivial
Base.print_array(io, t[])
println(io)
elseif FusionStyle(sectortype(t)) isa UniqueFusion
for (f₁, f₂) in fusiontrees(t)
println(io, "* Data for sector ", f₁.uncoupled, " ← ", f₂.uncoupled, ":")
Base.print_array(io, t[f₁, f₂])
println(io)
end
else
for (f₁, f₂) in fusiontrees(t)
println(io, "* Data for fusiontree ", f₁, " ← ", f₂, ":")
Base.print_array(io, t[f₁, f₂])
println(io)
end
end
function Base.showarg(io::IO, t::AdjointTensorMap, toplevel::Bool)
print(io, "adjoint(")
Base.showarg(io, parent(t), false)
print(io, ")")
return nothing
end
Loading
Loading