1212_flatten_tuples () = ()
1313flatten_tuples (ts:: Tuple ) = _flatten_tuples (ts... )
1414
15- _blocklength (blocklengths:: Tuple{Vararg{Int}} ) = length (blocklengths)
16- function _blockfirsts (blocklengths:: Tuple{Vararg{Int}} )
17- return ntuple (_blocklength (blocklengths)) do i
18- prev_blocklast =
19- isone (i) ? zero (eltype (blocklengths)) : _blocklasts (blocklengths)[i - 1 ]
20- return prev_blocklast + 1
21- end
22- end
23- _blocklasts (blocklengths:: Tuple{Vararg{Int}} ) = cumsum (blocklengths)
24-
2515collect_tuple (x) = (x,)
2616collect_tuple (x:: Ellipsis ) = x
2717collect_tuple (t:: Tuple ) = t
2818
29- const TupleOfTuples{N} = Tuple{Vararg{Tuple{Vararg{Int}},N}}
30-
31- abstract type AbstractBlockedPermutation{BlockLength,Length} end
32-
33- BlockArrays. blocks (blockedperm:: AbstractBlockedPermutation ) = error (" Not implemented" )
34-
35- function Base. Tuple (blockedperm:: AbstractBlockedPermutation )
36- return flatten_tuples (blocks (blockedperm))
37- end
38-
39- function BlockArrays. blocklengths (blockedperm:: AbstractBlockedPermutation )
40- return length .(blocks (blockedperm))
41- end
42-
43- function BlockArrays. blockfirsts (blockedperm:: AbstractBlockedPermutation )
44- return _blockfirsts (blocklengths (blockedperm))
45- end
46-
47- function BlockArrays. blocklasts (blockedperm:: AbstractBlockedPermutation )
48- return _blocklasts (blocklengths (blockedperm))
49- end
19+ #
20+ # =============================== AbstractBlockPermutation ===============================
21+ #
22+ abstract type AbstractBlockPermutation{BlockLength} <: AbstractBlockTuple{BlockLength} end
5023
51- Base. iterate (permblocks:: AbstractBlockedPermutation ) = iterate (Tuple (permblocks))
52- function Base. iterate (permblocks:: AbstractBlockedPermutation , state)
53- return iterate (Tuple (permblocks), state)
54- end
24+ widened_constructorof (:: Type{<:AbstractBlockPermutation} ) = BlockedTuple
5525
5626# Block a permutation based on the specified lengths.
5727# blockperm((4, 3, 2, 1), (2, 2)) == blockedperm((4, 3), (2, 1))
5828# TODO : Optimize with StaticNumbers.jl or generated functions, see:
5929# https://discourse.julialang.org/t/avoiding-type-instability-when-slicing-a-tuple/38567
6030function blockperm (perm:: Tuple{Vararg{Int}} , blocklengths:: Tuple{Vararg{Int}} )
61- starts = _blockfirsts (blocklengths)
62- stops = _blocklasts (blocklengths)
63- return blockedperm (ntuple (i -> perm[starts[i]: stops[i]], length (blocklengths))... )
64- end
65-
66- function Base. invperm (blockedperm:: AbstractBlockedPermutation )
67- return blockperm (invperm (Tuple (blockedperm)), blocklengths (blockedperm))
31+ return blockedperm (BlockedTuple (perm, blocklengths))
6832end
6933
70- Base. length (blockedperm:: AbstractBlockedPermutation ) = length (Tuple (blockedperm))
71- function BlockArrays. blocklength (blockedperm:: AbstractBlockedPermutation )
72- return length (blocks (blockedperm))
34+ function blockperm (perm:: Tuple{Vararg{Int}} , BlockLengths:: Val )
35+ return blockedperm (BlockedTuple (perm, BlockLengths))
7336end
7437
75- function Base. getindex (blockedperm:: AbstractBlockedPermutation , i:: Int )
76- return Tuple (blockedperm)[i]
77- end
78-
79- function Base. getindex (blockedperm:: AbstractBlockedPermutation , I:: AbstractUnitRange )
80- perm = Tuple (blockedperm)
81- return [perm[i] for i in I]
82- end
83-
84- function Base. getindex (blockedperm:: AbstractBlockedPermutation , b:: Block )
85- return blocks (blockedperm)[Int (b)]
86- end
87-
88- # Like `BlockRange`.
89- function blockeachindex (blockedperm:: AbstractBlockedPermutation )
90- return ntuple (i -> Block (i), blocklength (blockedperm))
38+ function Base. invperm (blockedperm:: AbstractBlockPermutation )
39+ # use Val to preserve compile time info
40+ return blockperm (invperm (Tuple (blockedperm)), Val (blocklengths (blockedperm)))
9141end
9242
9343#
9747# Bipartition a vector according to the
9848# bipartitioned permutation.
9949# Like `Base.permute!` block out-of-place and blocked.
100- function blockpermute (v, blockedperm:: AbstractBlockedPermutation )
50+ function blockpermute (v, blockedperm:: AbstractBlockPermutation )
10151 return map (blockperm -> map (i -> v[i], blockperm), blocks (blockedperm))
10252end
10353
@@ -106,8 +56,8 @@ function blockedperm(permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothin
10656 return blockedperm (length, permblocks... )
10757end
10858
109- function blockedperm (length :: Nothing , permblocks:: Tuple{Vararg{Int}} ...)
110- return blockedperm (Val (sum (Base . length, permblocks; init= zero (Bool))), permblocks... )
59+ function blockedperm (:: Nothing , permblocks:: Tuple{Vararg{Int}} ...)
60+ return blockedperm (Val (sum (length, permblocks; init= zero (Bool))), permblocks... )
11161end
11262
11363# blockedperm((3, 2), 1) == blockedperm((3, 2), (1,))
@@ -119,11 +69,15 @@ function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int,Ellipsis}...; kwar
11969 return blockedperm (collect_tuple .(permblocks)... ; kwargs... )
12070end
12171
72+ function blockedperm (bt:: AbstractBlockTuple )
73+ return blockedperm (Val (length (bt)), blocks (bt)... )
74+ end
75+
12276function _blockedperm_length (:: Nothing , specified_perm:: Tuple{Vararg{Int}} )
12377 return maximum (specified_perm)
12478end
12579
126- function _blockedperm_length (vallength:: Val , specified_perm :: Tuple{Vararg{Int}} )
80+ function _blockedperm_length (vallength:: Val , :: Tuple{Vararg{Int}} )
12781 return value (vallength)
12882end
12983
@@ -148,45 +102,69 @@ function blockedperm_indexin(collection, subs...)
148102 return blockedperm (map (sub -> BaseExtensions. indexin (sub, collection), subs)... )
149103end
150104
151- struct BlockedPermutation{BlockLength,Length,Blocks<: TupleOfTuples{BlockLength} } < :
152- AbstractBlockedPermutation{BlockLength,Length}
153- blocks:: Blocks
154- global function _BlockedPermutation (blocks:: TupleOfTuples )
155- len = sum (length, blocks; init= zero (Bool))
156- blocklength = length (blocks)
157- return new {blocklength,len,typeof(blocks)} (blocks)
105+ #
106+ # ================================== BlockedPermutation ==================================
107+ #
108+
109+ # for dispatch reason, it is convenient to have BlockLength as the first parameter
110+ struct BlockedPermutation{BlockLength,BlockLengths,Flat} < :
111+ AbstractBlockPermutation{BlockLength}
112+ flat:: Flat
113+
114+ function BlockedPermutation {BlockLength,BlockLengths} (
115+ flat:: Tuple
116+ ) where {BlockLength,BlockLengths}
117+ length (flat) != sum (BlockLengths; init= 0 ) &&
118+ throw (DimensionMismatch (" Invalid total length" ))
119+ length (BlockLengths) != BlockLength &&
120+ throw (DimensionMismatch (" Invalid total blocklength" ))
121+ any (BlockLengths .< 0 ) && throw (DimensionMismatch (" Invalid block length" ))
122+ return new {BlockLength,BlockLengths,typeof(flat)} (flat)
158123 end
159124end
160125
161- BlockArrays. blocks (blockedperm:: BlockedPermutation ) = getfield (blockedperm, :blocks )
126+ # Base interface
127+ Base. Tuple (blockedperm:: BlockedPermutation ) = getfield (blockedperm, :flat )
162128
163- function blockedperm (length:: Val , permblocks:: Tuple{Vararg{Int}} ...)
164- @assert value (length) == sum (Base. length, permblocks; init= zero (Bool))
165- blockedperm = _BlockedPermutation (permblocks)
129+ # BlockArrays interface
130+ function BlockArrays. blocklengths (
131+ :: Type{<:BlockedPermutation{<:Any,BlockLengths}}
132+ ) where {BlockLengths}
133+ return BlockLengths
134+ end
135+
136+ function blockedperm (:: Val , permblocks:: Tuple{Vararg{Int}} ...)
137+ blockedperm = BlockedPermutation {length(permblocks),length.(permblocks)} (
138+ flatten_tuples (permblocks)
139+ )
166140 @assert isperm (blockedperm)
167141 return blockedperm
168142end
169143
144+ #
145+ # ============================== BlockedTrivialPermutation ===============================
146+ #
170147trivialperm (length:: Union{Integer,Val} ) = ntuple (identity, length)
171148
172- struct BlockedTrivialPermutation{BlockLength,Length,Blocks<: TupleOfTuples{BlockLength} } < :
173- AbstractBlockedPermutation{BlockLength,Length}
174- blocks:: Blocks
175- global function _BlockedTrivialPermutation (blocklengths:: Tuple{Vararg{Int}} )
176- len = sum (blocklengths; init= zero (Bool))
177- blocklength = length (blocklengths)
178- permblocks = blocks (blockperm (trivialperm (len), blocklengths))
179- return new {blocklength,len,typeof(permblocks)} (permblocks)
180- end
149+ struct BlockedTrivialPermutation{BlockLength,BlockLengths} < :
150+ AbstractBlockPermutation{BlockLength} end
151+
152+ Base. Tuple (blockedperm:: BlockedTrivialPermutation ) = trivialperm (length (blockedperm))
153+
154+ # BlockArrays interface
155+ function BlockArrays. blocklengths (
156+ :: Type{<:BlockedTrivialPermutation{<:Any,BlockLengths}}
157+ ) where {BlockLengths}
158+ return BlockLengths
181159end
182160
183- BlockArrays . blocks ( blockedperm:: BlockedTrivialPermutation ) = getfield (blockedperm, :blocks )
161+ blockedperm (tp :: BlockedTrivialPermutation ) = tp
184162
185163function blockedtrivialperm (blocklengths:: Tuple{Vararg{Int}} )
186- return _BlockedTrivialPermutation (blocklengths)
164+ return BlockedTrivialPermutation {length (blocklengths),blocklengths} ( )
187165end
188166
189- function trivialperm (blockedperm:: AbstractBlockedPermutation )
167+ function trivialperm (blockedperm:: AbstractBlockTuple )
190168 return blockedtrivialperm (blocklengths (blockedperm))
191169end
192170Base. invperm (blockedperm:: BlockedTrivialPermutation ) = blockedperm
0 commit comments