Skip to content

Commit b876749

Browse files
authored
AbstractBlockPermutation <: AbstractBlockTuple (#11)
1 parent 81818e2 commit b876749

9 files changed

Lines changed: 280 additions & 182 deletions

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.1.2"
4+
version = "0.1.3"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/TensorAlgebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ module TensorAlgebra
22

33
export contract, contract!
44

5+
include("blockedtuple.jl")
56
include("blockedpermutation.jl")
67
include("BaseExtensions/BaseExtensions.jl")
7-
include("blockedtuple.jl")
88
include("fusedims.jl")
99
include("splitdims.jl")
1010
include("contract/contract.jl")

src/blockedpermutation.jl

Lines changed: 66 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -12,82 +12,32 @@ end
1212
_flatten_tuples() = ()
1313
flatten_tuples(ts::Tuple) = _flatten_tuples(ts...)
1414

15-
_blocklength(blocklengths::Tuple{Vararg{Int}}) = length(blocklengths)
16-
function _blockfirsts(blocklengths::Tuple{Vararg{Int}})
17-
return ntuple(_blocklength(blocklengths)) do i
18-
prev_blocklast =
19-
isone(i) ? zero(eltype(blocklengths)) : _blocklasts(blocklengths)[i - 1]
20-
return prev_blocklast + 1
21-
end
22-
end
23-
_blocklasts(blocklengths::Tuple{Vararg{Int}}) = cumsum(blocklengths)
24-
2515
collect_tuple(x) = (x,)
2616
collect_tuple(x::Ellipsis) = x
2717
collect_tuple(t::Tuple) = t
2818

29-
const TupleOfTuples{N} = Tuple{Vararg{Tuple{Vararg{Int}},N}}
30-
31-
abstract type AbstractBlockedPermutation{BlockLength,Length} end
32-
33-
BlockArrays.blocks(blockedperm::AbstractBlockedPermutation) = error("Not implemented")
34-
35-
function Base.Tuple(blockedperm::AbstractBlockedPermutation)
36-
return flatten_tuples(blocks(blockedperm))
37-
end
38-
39-
function BlockArrays.blocklengths(blockedperm::AbstractBlockedPermutation)
40-
return length.(blocks(blockedperm))
41-
end
42-
43-
function BlockArrays.blockfirsts(blockedperm::AbstractBlockedPermutation)
44-
return _blockfirsts(blocklengths(blockedperm))
45-
end
46-
47-
function BlockArrays.blocklasts(blockedperm::AbstractBlockedPermutation)
48-
return _blocklasts(blocklengths(blockedperm))
49-
end
19+
#
20+
# =============================== AbstractBlockPermutation ===============================
21+
#
22+
abstract type AbstractBlockPermutation{BlockLength} <: AbstractBlockTuple{BlockLength} end
5023

51-
Base.iterate(permblocks::AbstractBlockedPermutation) = iterate(Tuple(permblocks))
52-
function Base.iterate(permblocks::AbstractBlockedPermutation, state)
53-
return iterate(Tuple(permblocks), state)
54-
end
24+
widened_constructorof(::Type{<:AbstractBlockPermutation}) = BlockedTuple
5525

5626
# Block a permutation based on the specified lengths.
5727
# blockperm((4, 3, 2, 1), (2, 2)) == blockedperm((4, 3), (2, 1))
5828
# TODO: Optimize with StaticNumbers.jl or generated functions, see:
5929
# https://discourse.julialang.org/t/avoiding-type-instability-when-slicing-a-tuple/38567
6030
function blockperm(perm::Tuple{Vararg{Int}}, blocklengths::Tuple{Vararg{Int}})
61-
starts = _blockfirsts(blocklengths)
62-
stops = _blocklasts(blocklengths)
63-
return blockedperm(ntuple(i -> perm[starts[i]:stops[i]], length(blocklengths))...)
64-
end
65-
66-
function Base.invperm(blockedperm::AbstractBlockedPermutation)
67-
return blockperm(invperm(Tuple(blockedperm)), blocklengths(blockedperm))
31+
return blockedperm(BlockedTuple(perm, blocklengths))
6832
end
6933

70-
Base.length(blockedperm::AbstractBlockedPermutation) = length(Tuple(blockedperm))
71-
function BlockArrays.blocklength(blockedperm::AbstractBlockedPermutation)
72-
return length(blocks(blockedperm))
34+
function blockperm(perm::Tuple{Vararg{Int}}, BlockLengths::Val)
35+
return blockedperm(BlockedTuple(perm, BlockLengths))
7336
end
7437

75-
function Base.getindex(blockedperm::AbstractBlockedPermutation, i::Int)
76-
return Tuple(blockedperm)[i]
77-
end
78-
79-
function Base.getindex(blockedperm::AbstractBlockedPermutation, I::AbstractUnitRange)
80-
perm = Tuple(blockedperm)
81-
return [perm[i] for i in I]
82-
end
83-
84-
function Base.getindex(blockedperm::AbstractBlockedPermutation, b::Block)
85-
return blocks(blockedperm)[Int(b)]
86-
end
87-
88-
# Like `BlockRange`.
89-
function blockeachindex(blockedperm::AbstractBlockedPermutation)
90-
return ntuple(i -> Block(i), blocklength(blockedperm))
38+
function Base.invperm(blockedperm::AbstractBlockPermutation)
39+
# use Val to preserve compile time info
40+
return blockperm(invperm(Tuple(blockedperm)), Val(blocklengths(blockedperm)))
9141
end
9242

9343
#
@@ -97,7 +47,7 @@ end
9747
# Bipartition a vector according to the
9848
# bipartitioned permutation.
9949
# Like `Base.permute!` block out-of-place and blocked.
100-
function blockpermute(v, blockedperm::AbstractBlockedPermutation)
50+
function blockpermute(v, blockedperm::AbstractBlockPermutation)
10151
return map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm))
10252
end
10353

@@ -106,8 +56,8 @@ function blockedperm(permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothin
10656
return blockedperm(length, permblocks...)
10757
end
10858

109-
function blockedperm(length::Nothing, permblocks::Tuple{Vararg{Int}}...)
110-
return blockedperm(Val(sum(Base.length, permblocks; init=zero(Bool))), permblocks...)
59+
function blockedperm(::Nothing, permblocks::Tuple{Vararg{Int}}...)
60+
return blockedperm(Val(sum(length, permblocks; init=zero(Bool))), permblocks...)
11161
end
11262

11363
# blockedperm((3, 2), 1) == blockedperm((3, 2), (1,))
@@ -119,11 +69,15 @@ function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int,Ellipsis}...; kwar
11969
return blockedperm(collect_tuple.(permblocks)...; kwargs...)
12070
end
12171

72+
function blockedperm(bt::AbstractBlockTuple)
73+
return blockedperm(Val(length(bt)), blocks(bt)...)
74+
end
75+
12276
function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}})
12377
return maximum(specified_perm)
12478
end
12579

126-
function _blockedperm_length(vallength::Val, specified_perm::Tuple{Vararg{Int}})
80+
function _blockedperm_length(vallength::Val, ::Tuple{Vararg{Int}})
12781
return value(vallength)
12882
end
12983

@@ -148,45 +102,69 @@ function blockedperm_indexin(collection, subs...)
148102
return blockedperm(map(sub -> BaseExtensions.indexin(sub, collection), subs)...)
149103
end
150104

151-
struct BlockedPermutation{BlockLength,Length,Blocks<:TupleOfTuples{BlockLength}} <:
152-
AbstractBlockedPermutation{BlockLength,Length}
153-
blocks::Blocks
154-
global function _BlockedPermutation(blocks::TupleOfTuples)
155-
len = sum(length, blocks; init=zero(Bool))
156-
blocklength = length(blocks)
157-
return new{blocklength,len,typeof(blocks)}(blocks)
105+
#
106+
# ================================== BlockedPermutation ==================================
107+
#
108+
109+
# for dispatch reason, it is convenient to have BlockLength as the first parameter
110+
struct BlockedPermutation{BlockLength,BlockLengths,Flat} <:
111+
AbstractBlockPermutation{BlockLength}
112+
flat::Flat
113+
114+
function BlockedPermutation{BlockLength,BlockLengths}(
115+
flat::Tuple
116+
) where {BlockLength,BlockLengths}
117+
length(flat) != sum(BlockLengths; init=0) &&
118+
throw(DimensionMismatch("Invalid total length"))
119+
length(BlockLengths) != BlockLength &&
120+
throw(DimensionMismatch("Invalid total blocklength"))
121+
any(BlockLengths .< 0) && throw(DimensionMismatch("Invalid block length"))
122+
return new{BlockLength,BlockLengths,typeof(flat)}(flat)
158123
end
159124
end
160125

161-
BlockArrays.blocks(blockedperm::BlockedPermutation) = getfield(blockedperm, :blocks)
126+
# Base interface
127+
Base.Tuple(blockedperm::BlockedPermutation) = getfield(blockedperm, :flat)
162128

163-
function blockedperm(length::Val, permblocks::Tuple{Vararg{Int}}...)
164-
@assert value(length) == sum(Base.length, permblocks; init=zero(Bool))
165-
blockedperm = _BlockedPermutation(permblocks)
129+
# BlockArrays interface
130+
function BlockArrays.blocklengths(
131+
::Type{<:BlockedPermutation{<:Any,BlockLengths}}
132+
) where {BlockLengths}
133+
return BlockLengths
134+
end
135+
136+
function blockedperm(::Val, permblocks::Tuple{Vararg{Int}}...)
137+
blockedperm = BlockedPermutation{length(permblocks),length.(permblocks)}(
138+
flatten_tuples(permblocks)
139+
)
166140
@assert isperm(blockedperm)
167141
return blockedperm
168142
end
169143

144+
#
145+
# ============================== BlockedTrivialPermutation ===============================
146+
#
170147
trivialperm(length::Union{Integer,Val}) = ntuple(identity, length)
171148

172-
struct BlockedTrivialPermutation{BlockLength,Length,Blocks<:TupleOfTuples{BlockLength}} <:
173-
AbstractBlockedPermutation{BlockLength,Length}
174-
blocks::Blocks
175-
global function _BlockedTrivialPermutation(blocklengths::Tuple{Vararg{Int}})
176-
len = sum(blocklengths; init=zero(Bool))
177-
blocklength = length(blocklengths)
178-
permblocks = blocks(blockperm(trivialperm(len), blocklengths))
179-
return new{blocklength,len,typeof(permblocks)}(permblocks)
180-
end
149+
struct BlockedTrivialPermutation{BlockLength,BlockLengths} <:
150+
AbstractBlockPermutation{BlockLength} end
151+
152+
Base.Tuple(blockedperm::BlockedTrivialPermutation) = trivialperm(length(blockedperm))
153+
154+
# BlockArrays interface
155+
function BlockArrays.blocklengths(
156+
::Type{<:BlockedTrivialPermutation{<:Any,BlockLengths}}
157+
) where {BlockLengths}
158+
return BlockLengths
181159
end
182160

183-
BlockArrays.blocks(blockedperm::BlockedTrivialPermutation) = getfield(blockedperm, :blocks)
161+
blockedperm(tp::BlockedTrivialPermutation) = tp
184162

185163
function blockedtrivialperm(blocklengths::Tuple{Vararg{Int}})
186-
return _BlockedTrivialPermutation(blocklengths)
164+
return BlockedTrivialPermutation{length(blocklengths),blocklengths}()
187165
end
188166

189-
function trivialperm(blockedperm::AbstractBlockedPermutation)
167+
function trivialperm(blockedperm::AbstractBlockTuple)
190168
return blockedtrivialperm(blocklengths(blockedperm))
191169
end
192170
Base.invperm(blockedperm::BlockedTrivialPermutation) = blockedperm

src/blockedtuple.jl

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
# This file defines BlockedTuple, a Tuple of heterogeneous Tuple with a BlockArrays.jl
2-
# like interface
1+
# This file defines an abstract type AbstractBlockTuple and a concrete type BlockedTuple.
2+
# These types allow to store a Tuple of heterogeneous Tuples with a BlockArrays.jl like
3+
# interface.
34

45
using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange
56

@@ -8,7 +9,17 @@ using TypeParameterAccessors: unspecify_type_parameters
89
#
910
# ================================== AbstractBlockTuple ==================================
1011
#
11-
abstract type AbstractBlockTuple end
12+
# AbstractBlockTuple imposes BlockLength as first type parameter for easy dispatch
13+
# it makes no assumotion on storage type
14+
abstract type AbstractBlockTuple{BlockLength} end
15+
16+
constructorof(type::Type{<:AbstractBlockTuple}) = unspecify_type_parameters(type)
17+
widened_constructorof(type::Type{<:AbstractBlockTuple}) = constructorof(type)
18+
19+
# Like `BlockRange`.
20+
function blockeachindex(bt::AbstractBlockTuple)
21+
return ntuple(i -> Block(i), blocklength(bt))
22+
end
1223

1324
# Base interface
1425
Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),)
@@ -22,9 +33,8 @@ Base.getindex(bt::AbstractBlockTuple, r::AbstractUnitRange) = Tuple(bt)[r]
2233
Base.getindex(bt::AbstractBlockTuple, b::Block{1}) = blocks(bt)[Int(b)]
2334
function Base.getindex(bt::AbstractBlockTuple, br::BlockRange{1})
2435
r = Int.(br)
25-
T = unspecify_type_parameters(typeof(bt))
2636
flat = Tuple(bt)[blockfirsts(bt)[first(r)]:blocklasts(bt)[last(r)]]
27-
return T{blocklengths(bt)[r]}(flat)
37+
return widened_constructorof(typeof(bt))(flat, blocklengths(bt)[r])
2838
end
2939
function Base.getindex(bt::AbstractBlockTuple, bi::BlockIndexRange{1})
3040
return bt[Block(bi)][only(bi.indices)]
@@ -33,12 +43,14 @@ end
3343
Base.iterate(bt::AbstractBlockTuple) = iterate(Tuple(bt))
3444
Base.iterate(bt::AbstractBlockTuple, i::Int) = iterate(Tuple(bt), i)
3545

36-
Base.length(bt::AbstractBlockTuple) = length(Tuple(bt))
37-
3846
Base.lastindex(bt::AbstractBlockTuple) = length(bt)
3947

48+
Base.length(bt::AbstractBlockTuple) = sum(blocklengths(bt); init=0)
49+
4050
function Base.map(f, bt::AbstractBlockTuple)
41-
return unspecify_type_parameters(typeof(bt)){blocklengths(bt)}(map(f, Tuple(bt)))
51+
BL = blocklengths(bt)
52+
# use Val to preserve compile time knowledge of BL
53+
return widened_constructorof(typeof(bt))(map(f, Tuple(bt)), Val(BL))
4254
end
4355

4456
# Broadcast interface
@@ -57,19 +69,20 @@ end
5769
function Base.copy(
5870
bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}}
5971
) where {BlockLengths,BT}
60-
return BT{BlockLengths}(bc.f.((Tuple.(bc.args))...))
72+
return widened_constructorof(BT)(bc.f.((Tuple.(bc.args))...), Val(BlockLengths))
6173
end
6274

6375
# BlockArrays interface
76+
BlockArrays.blockfirsts(::AbstractBlockTuple{0}) = ()
6477
function BlockArrays.blockfirsts(bt::AbstractBlockTuple)
6578
return (0, cumsum(Base.front(blocklengths(bt)))...) .+ 1
6679
end
6780

6881
function BlockArrays.blocklasts(bt::AbstractBlockTuple)
69-
return cumsum(blocklengths(bt)[begin:end])
82+
return cumsum(blocklengths(bt))
7083
end
7184

72-
BlockArrays.blocklength(bt::AbstractBlockTuple) = length(blocklengths(bt))
85+
BlockArrays.blocklength(::AbstractBlockTuple{BlockLength}) where {BlockLength} = BlockLength
7386

7487
BlockArrays.blocklengths(bt::AbstractBlockTuple) = blocklengths(typeof(bt))
7588

@@ -79,29 +92,46 @@ function BlockArrays.blocks(bt::AbstractBlockTuple)
7992
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt))
8093
end
8194

82-
#
95+
# length(BlockLengths) != BlockLength && throw(DimensionMismatch("Invalid blocklength"))
96+
8397
# ===================================== BlockedTuple =====================================
8498
#
85-
struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple
99+
struct BlockedTuple{BlockLength,BlockLengths,Flat} <: AbstractBlockTuple{BlockLength}
86100
flat::Flat
87101

88-
function BlockedTuple{BlockLengths}(flat::Tuple) where {BlockLengths}
89-
length(flat) != sum(BlockLengths) && throw(DimensionMismatch("Invalid total length"))
90-
return new{BlockLengths,typeof(flat)}(flat)
102+
function BlockedTuple{BlockLength,BlockLengths}(
103+
flat::Tuple
104+
) where {BlockLength,BlockLengths}
105+
length(BlockLengths) != BlockLength && throw(DimensionMismatch("Invalid blocklength"))
106+
length(flat) != sum(BlockLengths; init=0) &&
107+
throw(DimensionMismatch("Invalid total length"))
108+
any(BlockLengths .< 0) && throw(DimensionMismatch("Invalid block length"))
109+
return new{BlockLength,BlockLengths,typeof(flat)}(flat)
91110
end
92111
end
93112

94113
# TensorAlgebra Interface
95-
tuplemortar(tt::Tuple{Vararg{Tuple}}) = BlockedTuple{length.(tt)}(flatten_tuples(tt))
114+
function tuplemortar(tt::Tuple{Vararg{Tuple}})
115+
return BlockedTuple{length(tt),length.(tt)}(flatten_tuples(tt))
116+
end
96117
function BlockedTuple(flat::Tuple, BlockLengths::Tuple{Vararg{Int}})
97-
return BlockedTuple{BlockLengths}(flat)
118+
return BlockedTuple{length(BlockLengths),BlockLengths}(flat)
119+
end
120+
function BlockedTuple(flat::Tuple, ::Val{BlockLengths}) where {BlockLengths}
121+
# use Val to preserve compile time knowledge of BL
122+
return BlockedTuple{length(BlockLengths),BlockLengths}(flat)
123+
end
124+
function BlockedTuple(bt::AbstractBlockTuple)
125+
bl = blocklengths(bt)
126+
return BlockedTuple{length(bl),bl}(Tuple(bt))
98127
end
99-
BlockedTuple(bt::AbstractBlockTuple) = BlockedTuple{blocklengths(bt)}(Tuple(bt))
100128

101129
# Base interface
102130
Base.Tuple(bt::BlockedTuple) = bt.flat
103131

104132
# BlockArrays interface
105-
function BlockArrays.blocklengths(::Type{<:BlockedTuple{BlockLengths}}) where {BlockLengths}
133+
function BlockArrays.blocklengths(
134+
::Type{<:BlockedTuple{<:Any,BlockLengths}}
135+
) where {BlockLengths}
106136
return BlockLengths
107137
end

src/fusedims.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@ function fusedims(a::AbstractArray, permblocks...)
5151
end
5252

5353
function fuseaxes(
54-
axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockedPermutation
54+
axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation
5555
)
5656
axesblocks = blockpermute(axes, blockedperm)
5757
return map(block -> (block...), axesblocks)
5858
end
5959

60-
function fuseaxes(a::AbstractArray, blockedperm::AbstractBlockedPermutation)
60+
function fuseaxes(a::AbstractArray, blockedperm::AbstractBlockPermutation)
6161
return fuseaxes(axes(a), blockedperm)
6262
end
6363

0 commit comments

Comments
 (0)