Skip to content

Commit 53bd1a7

Browse files
Add Base.repeat methods for Tile (#244)
* Add `Base.repeat` methods for `Tile` * Avoid generated functions and Val. * Add `repeat` line; update ordering of operations --------- Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent fdc4162 commit 53bd1a7

3 files changed

Lines changed: 147 additions & 2 deletions

File tree

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,13 @@ Tile IR operations.
184184
### Shape
185185
| Operation | Description |
186186
|-----------|-------------|
187-
| `ct.broadcast_to(tile, shape)` | Broadcast to target shape |
188-
| `transpose(tile)` | Transpose 2D tile |
189187
| `reshape(tile, shape)` | Reshape (same element count) |
188+
| `transpose(tile)` | Transpose 2D tile |
190189
| `permutedims(tile, perm)` | Permute dimensions |
190+
| `repeat(tile, counts...)` `repeat(tile; inner, outer)` | Repeat values along dimensions |
191191
| `ct.extract(tile, index, shape)` | Extract sub-tile |
192192
| `ct.cat((a, b), axis)` | Concatenate tiles |
193+
| `ct.broadcast_to(tile, shape)` | Broadcast to target shape |
193194
| `dropdims(tile; dims)` | Remove singleton dimensions |
194195

195196
### Matrix

src/language/operations.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,81 @@ reshaped = reshape(tile, (2, 16)) # Shape (2, 16), still 32 elements
786786
end
787787
@inline Base.reshape(tile::Tile{T}, dims::Int...) where {T} = reshape(tile, dims)
788788

789+
"""
790+
repeat(tile::Tile, counts::Integer...) -> Tile
791+
792+
Repeat a tile along each dimension (outer repetition), matching `Base.repeat`.
793+
`counts[i]` is the number of times to tile the whole tile along dimension `i`;
794+
dimensions beyond `length(counts)` are repeated once, and counts beyond
795+
`ndims(tile)` introduce new trailing dimensions.
796+
797+
`counts` must be compile-time constants, and every resulting dimension
798+
`size(tile, i) * counts[i]` (as well as the new repeat dimensions) must be a
799+
power of two.
800+
801+
# Example
802+
```julia
803+
tile = ct.load(arr, (1, 1), (4, 8)) # Shape (4, 8)
804+
tiled = repeat(tile, 2, 2) # Shape (8, 16)
805+
```
806+
"""
807+
Base.repeat(tile::Tile, counts::Integer...) = repeat(tile; outer = counts)
808+
809+
"""
810+
repeat(tile::Tile; inner=nothing, outer=nothing) -> Tile
811+
812+
Keyword form of `repeat`, matching `Base.repeat`. `inner[i]` repeats each element
813+
`inner[i]` times along dimension `i` (inner repetition), while `outer[i]` tiles the
814+
whole tile `outer[i]` times along dimension `i` (outer repetition). When both are
815+
given, the inner repetition is applied first. Each of `inner`/`outer` may be an
816+
`Integer` or a tuple, and must be a compile-time constant.
817+
818+
# Example
819+
```julia
820+
tile = ct.load(arr, (1, 1), (2, 4)) # Shape (2, 4)
821+
repeat(tile; inner=(2, 1), outer=(1, 2)) # Shape (4, 8)
822+
```
823+
"""
824+
function Base.repeat(tile::Tile; inner = nothing, outer = nothing)
825+
t = inner === nothing ? tile :
826+
_repeat(tile, :inner, inner isa Integer ? (inner,) : Tuple(inner))
827+
return outer === nothing ? t :
828+
_repeat(t, :outer, outer isa Integer ? (outer,) : Tuple(outer))
829+
end
830+
831+
# Implements both inner and outer repetition by reshaping the tile to interleave
832+
# a singleton dimension next to each data dimension, broadcasting that singleton
833+
# up to the repeat count, then collapsing each pair back together. For outer
834+
# repetition the data dimension is the fast (inner) one within each pair; for
835+
# inner repetition the repeat count is. Shapes derive from the tile type and the
836+
# (constant) counts, so const-prop folds them away.
837+
function _repeat(tile::Tile, mode::Symbol, counts::Tuple{Vararg{Integer}})
838+
sz = size(tile)
839+
N = ndims(tile)
840+
M = length(counts)
841+
P = max(N, M)
842+
szp = ntuple(i -> i <= N ? sz[i] : 1, P)
843+
cntp = ntuple(i -> i <= M ? Int(counts[i]) : 1, P)
844+
845+
# Nothing to repeat and no new dimensions: identity.
846+
(all(==(1), cntp) && P == N) && return tile
847+
848+
data_first = mode === :outer
849+
interleaved = ntuple(2P) do j
850+
pair, lo = divrem(j - 1, 2)
851+
(lo == 0) == data_first ? szp[pair + 1] : 1
852+
end
853+
target = ntuple(2P) do j
854+
pair, lo = divrem(j - 1, 2)
855+
(lo == 0) == data_first ? szp[pair + 1] : cntp[pair + 1]
856+
end
857+
final = ntuple(i -> szp[i] * cntp[i], P)
858+
859+
t = reshape(tile, interleaved)
860+
t = broadcast_to(t, target)
861+
return reshape(t, final)
862+
end
863+
789864
"""
790865
permutedims(tile::Tile{T, S}, perm) -> Tile{T, permuted_shape}
791866

test/device/tile.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,75 @@ end
473473
end
474474
end
475475

476+
@testset "repeat" begin
477+
@testset "1D outer repeat matches Base.repeat" begin
478+
function repeat_1d_kernel(x::ct.TileArray{Float32,1}, y::ct.TileArray{Float32,1})
479+
bid = ct.bid(1)
480+
tile = ct.load(x, bid, (8,))
481+
ct.store(y, bid, repeat(tile, 4)) # (8,) -> (32,)
482+
return
483+
end
484+
485+
x = CuArray(Float32.(1:8))
486+
y = CUDA.zeros(Float32, 32)
487+
488+
@cuda backend=cuTile repeat_1d_kernel(x, y)
489+
490+
@test Array(y) repeat(Float32.(1:8), 4)
491+
end
492+
493+
@testset "2D outer repeat matches Base.repeat" begin
494+
function repeat_2d_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2})
495+
bid = ct.bid(1)
496+
tile = ct.load(x, (bid, 1), (2, 4))
497+
ct.store(y, (bid, 1), repeat(tile, 2, 2)) # (2, 4) -> (4, 8)
498+
return
499+
end
500+
501+
src = Float32.(reshape(1:8, 2, 4))
502+
x = CuArray(src)
503+
y = CUDA.zeros(Float32, 4, 8)
504+
505+
@cuda backend=cuTile repeat_2d_kernel(x, y)
506+
507+
@test Array(y) repeat(src, 2, 2)
508+
end
509+
510+
@testset "2D inner repeat matches Base.repeat" begin
511+
function repeat_inner_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2})
512+
bid = ct.bid(1)
513+
tile = ct.load(x, (bid, 1), (2, 4))
514+
ct.store(y, (bid, 1), repeat(tile; inner=(2, 1))) # (2, 4) -> (4, 4)
515+
return
516+
end
517+
518+
src = Float32.(reshape(1:8, 2, 4))
519+
x = CuArray(src)
520+
y = CUDA.zeros(Float32, 4, 4)
521+
522+
@cuda backend=cuTile repeat_inner_kernel(x, y)
523+
524+
@test Array(y) repeat(src; inner=(2, 1))
525+
end
526+
527+
@testset "2D combined inner+outer repeat matches Base.repeat" begin
528+
function repeat_inner_outer_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2})
529+
bid = ct.bid(1)
530+
tile = ct.load(x, (bid, 1), (2, 4))
531+
ct.store(y, (bid, 1), repeat(tile; inner=(2, 1), outer=(1, 2))) # (2, 4) -> (4, 8)
532+
return
533+
end
534+
535+
src = Float32.(reshape(1:8, 2, 4))
536+
x = CuArray(src)
537+
y = CUDA.zeros(Float32, 4, 8)
538+
539+
@cuda backend=cuTile repeat_inner_outer_kernel(x, y)
540+
541+
@test Array(y) repeat(src; inner=(2, 1), outer=(1, 2))
542+
end
543+
end
544+
476545
@testset "permutedims" begin
477546
@testset "2D permutedims (transpose-like)" begin
478547
function permute_2d_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2})

0 commit comments

Comments
 (0)