Skip to content

Commit 4da1880

Browse files
committed
some indexing reworking
1 parent 015760e commit 4da1880

3 files changed

Lines changed: 146 additions & 79 deletions

File tree

src/TensorKit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ export ℤ₂Space, ℤ₃Space, ℤ₄Space, U₁Space, CU₁Space, SU₂Space
5656

5757
# Export tensor map methods
5858
export domain, codomain, numind, numout, numin, domainind, codomainind, allind
59-
export spacetype, storagetype, scalartype, tensormaptype
60-
export blocksectors, blockdim, block, blocks
59+
export spacetype, sectortype, storagetype, scalartype, tensormaptype
60+
export blocksectors, blockdim, block, blocks, subblocks, subblock
6161

6262
# random methods for constructor
6363
export randisometry, randisometry!, rand, rand!, randn, randn!

src/tensors/abstracttensor.jl

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,12 @@ Return an iterator over all splitting - fusion tree pairs of a tensor.
250250
"""
251251
fusiontrees(t::AbstractTensorMap) = fusionblockstructure(t).fusiontreelist
252252

253+
fusiontreetype(t::AbstractTensorMap) = fusiontreetype(typeof(t))
254+
function fusiontreetype(::Type{T}) where {T <: AbstractTensorMap}
255+
I = sectortype(T)
256+
return Tuple{fusiontreetype(I, numout(T)), fusiontreetype(I, numin(T))}
257+
end
258+
253259
# auxiliary function
254260
@inline function trivial_fusiontree(t::AbstractTensorMap)
255261
sectortype(t) === Trivial ||
@@ -295,6 +301,137 @@ function blocktype(::Type{T}) where {T <: AbstractTensorMap}
295301
return Core.Compiler.return_type(block, Tuple{T, sectortype(T)})
296302
end
297303

304+
# tensor data: subblock access
305+
# ----------------------------
306+
@doc """
307+
subblocks(t::AbstractTensorMap)
308+
309+
Return an iterator over all subblocks of a tensor, i.e. all fusiontrees and their
310+
corresponding tensor subblocks.
311+
312+
See also [`subblock`](@ref), [`fusiontrees`](@ref), and [`hassubblock`](@ref).
313+
"""
314+
subblocks(t::AbstractTensorMap) = SubblockIterator(t, fusiontrees(t))
315+
316+
const _doc_subblock = """
317+
Return a view into the data of `t` corresponding to the splitting - fusion tree pair
318+
`(f₁, f₂)`. In particular, this is an `AbstractArray{T}` with `T = scalartype(t)`, of size
319+
`(dims(codomain(t), f₁.uncoupled)..., dims(codomain(t), f₂.uncoupled)...)`.
320+
321+
Whenever `FusionStyle(sectortype(t))`, it is also possible to provide only the external
322+
`sectors`, in which case the fusion tree pair will be constructed automatically.
323+
"""
324+
325+
@doc """
326+
subblock(t::AbstractTensorMap, (f₁, f₂)::Tuple{FusionTree,FusionTree})
327+
subblock(t::AbstractTensorMap, sectors::Tuple{Vararg{Sector}})
328+
329+
$_doc_subblock
330+
331+
In general, new tensor types should provide an implementation of this function for the
332+
fusion tree signature.
333+
334+
See also [`subblocks`](@ref) and [`fusiontrees`](@ref).
335+
""" subblock
336+
337+
Base.@propagate_inbounds function subblock(t::AbstractTensorMap, sectors::Tuple{I, Vararg{I}}) where {I <: Sector}
338+
# input checking
339+
I === sectortype(t) || throw(SectorMismatch("Not a valid sectortype for this tensor."))
340+
FusionStyle(I) isa UniqueFusion ||
341+
throw(SectorMismatch("Indexing with sectors is only possible for unique fusion styles."))
342+
length(sectors) == numind(t) || throw(ArgumentError("invalid number of sectors"))
343+
344+
# convert to fusiontrees
345+
s₁ = TupleTools.getindices(sectors, codomainind(t))
346+
s₂ = map(dual, TupleTools.getindices(sectors, domainind(t)))
347+
c1 = length(s₁) == 0 ? unit(I) : (length(s₁) == 1 ? s₁[1] : first((s₁...)))
348+
@boundscheck begin
349+
hassector(codomain(t), s₁) && hassector(domain(t), s₂) || throw(BoundsError(t, sectors))
350+
c2 = length(s₂) == 0 ? unit(I) : (length(s₂) == 1 ? s₂[1] : first((s₂...)))
351+
c2 == c1 || throw(SectorMismatch("Not a valid fusion channel for this tensor"))
352+
end
353+
f₁ = FusionTree(s₁, c1, map(isdual, tuple(codomain(t)...)))
354+
f₂ = FusionTree(s₂, c1, map(isdual, tuple(domain(t)...)))
355+
return @inbounds subblock(t, (f₁, f₂))
356+
end
357+
Base.@propagate_inbounds function subblock(t::AbstractTensorMap, sectors::Tuple)
358+
return subblock(t, map(Base.Fix1(convert, sectortype(t)), sectors))
359+
end
360+
361+
@doc """
362+
subblocktype(t)
363+
subblocktype(::Type{T})
364+
365+
Return the type of the tensor subblocks of a tensor.
366+
""" subblocktype
367+
368+
function subblocktype(::Type{T}) where {T <: AbstractTensorMap}
369+
return Core.Compiler.return_type(subblock, Tuple{T, fusiontreetype(T)})
370+
end
371+
subblocktype(t) = subblocktype(typeof(t))
372+
subblocktype(T::Type) = throw(MethodError(subblocktype, (T,)))
373+
374+
# Indexing behavior
375+
# -----------------
376+
@doc """
377+
Base.view(t::AbstractTensorMap, sectors::Tuple{Vararg{Sector}})
378+
Base.view(t::AbstractTensorMap, f₁::FusionTree, f₂::FusionTree)
379+
380+
$_doc_subblock
381+
382+
!!! note
383+
Contrary to Julia's array types, the default indexing behavior is to return a view
384+
into the tensor data. As a result, `getindex` and `view` have the same behavior.
385+
386+
See also [`subblock`](@ref), [`subblocks`](@ref) and [`fusiontrees`](@ref).
387+
""" Base.view(::AbstractTensorMap, ::Tuple{I, Vararg{I}}) where {I <: Sector},
388+
Base.view(::AbstractTensorMap, ::FusionTree, ::FusionTree)
389+
390+
@inline Base.view(t::AbstractTensorMap, sectors::Tuple{I, Vararg{I}}) where {I <: Sector} =
391+
subblock(t, sectors)
392+
@inline Base.view(t::AbstractTensorMap, f₁::FusionTree, f₂::FusionTree) =
393+
subblock(t, (f₁, f₂))
394+
395+
# by default getindex returns views
396+
@doc """
397+
Base.getindex(t::AbstractTensorMap, sectors::Tuple{Vararg{Sector}})
398+
t[sectors]
399+
Base.getindex(t::AbstractTensorMap, f₁::FusionTree, f₂::FusionTree)
400+
t[f₁, f₂]
401+
402+
$_doc_subblock
403+
404+
!!! warning
405+
Contrary to Julia's array types, the default behavior is to return a view into the tensor data.
406+
As a result, modifying the view will modify the data in the tensor.
407+
408+
See also [`subblock`](@ref), [`subblocks`](@ref) and [`fusiontrees`](@ref).
409+
""" Base.getindex(::AbstractTensorMap, ::Tuple{I, Vararg{I}}) where {I <: Sector},
410+
Base.getindex(::AbstractTensorMap, ::FusionTree, ::FusionTree)
411+
412+
@inline Base.getindex(t::AbstractTensorMap, sectors::Tuple{I, Vararg{I}}) where {I <: Sector} =
413+
view(t, sectors)
414+
@inline Base.getindex(t::AbstractTensorMap, f₁::FusionTree, f₂::FusionTree) =
415+
view(t, f₁, f₂)
416+
417+
@doc """
418+
Base.setindex!(t::AbstractTensorMap, v, sectors::Tuple{Vararg{Sector}})
419+
t[sectors] = v
420+
Base.setindex!(t::AbstractTensorMap, v, f₁::FusionTree, f₂::FusionTree)
421+
t[f₁, f₂] = v
422+
423+
Copies `v` into the data slice of `t` corresponding to the splitting - fusion tree pair `(f₁, f₂)`.
424+
By default, `v` can be any object that can be copied into the view associated with `t[f₁, f₂]`.
425+
426+
See also [`subblock`](@ref), [`subblocks`](@ref) and [`fusiontrees`](@ref).
427+
""" Base.setindex!(::AbstractTensorMap, ::Any, ::Tuple{I, Vararg{I}}) where {I <: Sector},
428+
Base.setindex!(::AbstractTensorMap, ::Any, ::FusionTree, ::FusionTree)
429+
430+
@inline Base.setindex!(t::AbstractTensorMap, v, sectors::Tuple{I, Vararg{I}}) where {I <: Sector} =
431+
copy!(view(t, sectors), v)
432+
@inline Base.setindex!(t::AbstractTensorMap, v, f₁::FusionTree, f₂::FusionTree) =
433+
copy!(view(t, (f₁, f₂)), v)
434+
298435
# Derived indexing behavior for tensors with trivial symmetry
299436
#-------------------------------------------------------------
300437
using TensorKit.Strided: SliceIndex

src/tensors/tensor.jl

Lines changed: 7 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -464,26 +464,10 @@ function Base.getindex(iter::BlockIterator{<:TensorMap}, c::Sector)
464464
return reshape(view(iter.t.data, r), (d₁, d₂))
465465
end
466466

467-
# Indexing and getting and setting the data at the subblock level
468-
#-----------------------------------------------------------------
469-
"""
470-
Base.getindex(t::TensorMap{T,S,N₁,N₂,I},
471-
f₁::FusionTree{I,N₁},
472-
f₂::FusionTree{I,N₂}) where {T,SN₁,N₂,I<:Sector}
473-
-> StridedViews.StridedView
474-
t[f₁, f₂]
475-
476-
Return a view into the data slice of `t` corresponding to the splitting - fusion tree pair
477-
`(f₁, f₂)`. In particular, if `f₁.coupled == f₂.coupled == c`, then a
478-
`StridedViews.StridedView` of size
479-
`(dims(codomain(t), f₁.uncoupled)..., dims(domain(t), f₂.uncoupled))` is returned which
480-
represents the slice of `block(t, c)` whose row indices correspond to `f₁.uncoupled` and
481-
column indices correspond to `f₂.uncoupled`.
482-
483-
See also [`Base.setindex!(::TensorMap{T,S,N₁,N₂}, ::Any, ::FusionTree{I,N₁}, ::FusionTree{I,N₂}) where {T,S,N₁,N₂,I<:Sector}`](@ref)
484-
"""
485-
@inline function Base.getindex(
486-
t::TensorMap{T, S, N₁, N₂}, f₁::FusionTree{I, N₁}, f₂::FusionTree{I, N₂}
467+
# Getting and setting the data at the subblock level
468+
# --------------------------------------------------
469+
function subblock(
470+
t::TensorMap{T, S, N₁, N₂}, (f₁, f₂)::Tuple{FusionTree{I, N₁}, FusionTree{I, N₂}}
487471
) where {T, S, N₁, N₂, I <: Sector}
488472
structure = fusionblockstructure(t)
489473
@boundscheck begin
@@ -495,71 +479,17 @@ See also [`Base.setindex!(::TensorMap{T,S,N₁,N₂}, ::Any, ::FusionTree{I,N₁
495479
return StridedView(t.data, sz, str, offset)
496480
end
497481
end
482+
498483
# The following is probably worth special casing for trivial tensors
499-
@inline function Base.getindex(
500-
t::TensorMap{T, S, N₁, N₂}, f₁::FusionTree{Trivial, N₁}, f₂::FusionTree{Trivial, N₂}
484+
@inline function subblock(
485+
t::TensorMap{T, S, N₁, N₂}, (f₁, f₂)::Tuple{FusionTree{Trivial, N₁}, FusionTree{Trivial, N₂}}
501486
) where {T, S, N₁, N₂}
502487
@boundscheck begin
503488
sectortype(t) == Trivial || throw(SectorMismatch())
504489
end
505490
return sreshape(StridedView(t.data), (dims(codomain(t))..., dims(domain(t))...))
506491
end
507492

508-
"""
509-
Base.setindex!(t::TensorMap{T,S,N₁,N₂,I},
510-
v,
511-
f₁::FusionTree{I,N₁},
512-
f₂::FusionTree{I,N₂}) where {T,S,N₁,N₂,I<:Sector}
513-
t[f₁, f₂] = v
514-
515-
Copies `v` into the data slice of `t` corresponding to the splitting - fusion tree pair
516-
`(f₁, f₂)`. Here, `v` can be any object that can be copied into a `StridedViews.StridedView`
517-
of size `(dims(codomain(t), f₁.uncoupled)..., dims(domain(t), f₂.uncoupled))` using
518-
`Base.copy!`.
519-
520-
See also [`Base.getindex(::TensorMap{T,S,N₁,N₂}, ::FusionTree{I,N₁}, ::FusionTree{I,N₂}) where {T,S,N₁,N₂,I<:Sector}`](@ref)
521-
"""
522-
@propagate_inbounds function Base.setindex!(
523-
t::TensorMap{T, S, N₁, N₂}, v, f₁::FusionTree{I, N₁}, f₂::FusionTree{I, N₂}
524-
) where {T, S, N₁, N₂, I <: Sector}
525-
return copy!(getindex(t, f₁, f₂), v)
526-
end
527-
528-
"""
529-
Base.getindex(t::TensorMap
530-
sectors::NTuple{N₁+N₂,I}) where {N₁,N₂,I<:Sector}
531-
-> StridedViews.StridedView
532-
t[sectors]
533-
534-
Return a view into the data slice of `t` corresponding to the splitting - fusion tree pair
535-
with combined uncoupled charges `sectors`. In particular, if `sectors == (s₁..., s₂...)`
536-
where `s₁` and `s₂` correspond to the uncoupled charges in the codomain and domain
537-
respectively, then a `StridedViews.StridedView` of size
538-
`(dims(codomain(t), s₁)..., dims(domain(t), s₂))` is returned.
539-
540-
This method is only available for the case where `FusionStyle(I) isa UniqueFusion`,
541-
since it assumes a uniquely defined coupled charge.
542-
"""
543-
@inline function Base.getindex(t::TensorMap, sectors::Tuple{I, Vararg{I}}) where {I <: Sector}
544-
I === sectortype(t) || throw(SectorMismatch("Not a valid sectortype for this tensor."))
545-
FusionStyle(I) isa UniqueFusion ||
546-
throw(SectorMismatch("Indexing with sectors only possible if unique fusion"))
547-
length(sectors) == numind(t) ||
548-
throw(ArgumentError("Number of sectors does not match."))
549-
s₁ = TupleTools.getindices(sectors, codomainind(t))
550-
s₂ = map(dual, TupleTools.getindices(sectors, domainind(t)))
551-
c1 = length(s₁) == 0 ? unit(I) : (length(s₁) == 1 ? s₁[1] : first((s₁...)))
552-
@boundscheck begin
553-
c2 = length(s₂) == 0 ? unit(I) : (length(s₂) == 1 ? s₂[1] : first((s₂...)))
554-
c2 == c1 || throw(SectorMismatch("Not a valid sector for this tensor"))
555-
hassector(codomain(t), s₁) && hassector(domain(t), s₂)
556-
end
557-
f₁ = FusionTree(s₁, c1, map(isdual, tuple(codomain(t)...)))
558-
f₂ = FusionTree(s₂, c1, map(isdual, tuple(domain(t)...)))
559-
@inbounds begin
560-
return t[f₁, f₂]
561-
end
562-
end
563493
@propagate_inbounds function Base.getindex(t::TensorMap, sectors::Tuple)
564494
return t[map(sectortype(t), sectors)]
565495
end

0 commit comments

Comments
 (0)