diff --git a/README.md b/README.md index a753db39..e53ce14b 100644 --- a/README.md +++ b/README.md @@ -184,12 +184,13 @@ Tile IR operations. ### Shape | Operation | Description | |-----------|-------------| -| `ct.broadcast_to(tile, shape)` | Broadcast to target shape | -| `transpose(tile)` | Transpose 2D tile | | `reshape(tile, shape)` | Reshape (same element count) | +| `transpose(tile)` | Transpose 2D tile | | `permutedims(tile, perm)` | Permute dimensions | +| `repeat(tile, counts...)` `repeat(tile; inner, outer)` | Repeat values along dimensions | | `ct.extract(tile, index, shape)` | Extract sub-tile | | `ct.cat((a, b), axis)` | Concatenate tiles | +| `ct.broadcast_to(tile, shape)` | Broadcast to target shape | | `dropdims(tile; dims)` | Remove singleton dimensions | ### Matrix diff --git a/src/language/operations.jl b/src/language/operations.jl index 6328ce36..3f551976 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -786,6 +786,81 @@ reshaped = reshape(tile, (2, 16)) # Shape (2, 16), still 32 elements end @inline Base.reshape(tile::Tile{T}, dims::Int...) where {T} = reshape(tile, dims) +""" + repeat(tile::Tile, counts::Integer...) -> Tile + +Repeat a tile along each dimension (outer repetition), matching `Base.repeat`. +`counts[i]` is the number of times to tile the whole tile along dimension `i`; +dimensions beyond `length(counts)` are repeated once, and counts beyond +`ndims(tile)` introduce new trailing dimensions. + +`counts` must be compile-time constants, and every resulting dimension +`size(tile, i) * counts[i]` (as well as the new repeat dimensions) must be a +power of two. + +# Example +```julia +tile = ct.load(arr, (1, 1), (4, 8)) # Shape (4, 8) +tiled = repeat(tile, 2, 2) # Shape (8, 16) +``` +""" +Base.repeat(tile::Tile, counts::Integer...) = repeat(tile; outer = counts) + +""" + repeat(tile::Tile; inner=nothing, outer=nothing) -> Tile + +Keyword form of `repeat`, matching `Base.repeat`. `inner[i]` repeats each element +`inner[i]` times along dimension `i` (inner repetition), while `outer[i]` tiles the +whole tile `outer[i]` times along dimension `i` (outer repetition). When both are +given, the inner repetition is applied first. Each of `inner`/`outer` may be an +`Integer` or a tuple, and must be a compile-time constant. + +# Example +```julia +tile = ct.load(arr, (1, 1), (2, 4)) # Shape (2, 4) +repeat(tile; inner=(2, 1), outer=(1, 2)) # Shape (4, 8) +``` +""" +function Base.repeat(tile::Tile; inner = nothing, outer = nothing) + t = inner === nothing ? tile : + _repeat(tile, :inner, inner isa Integer ? (inner,) : Tuple(inner)) + return outer === nothing ? t : + _repeat(t, :outer, outer isa Integer ? (outer,) : Tuple(outer)) +end + +# Implements both inner and outer repetition by reshaping the tile to interleave +# a singleton dimension next to each data dimension, broadcasting that singleton +# up to the repeat count, then collapsing each pair back together. For outer +# repetition the data dimension is the fast (inner) one within each pair; for +# inner repetition the repeat count is. Shapes derive from the tile type and the +# (constant) counts, so const-prop folds them away. +function _repeat(tile::Tile, mode::Symbol, counts::Tuple{Vararg{Integer}}) + sz = size(tile) + N = ndims(tile) + M = length(counts) + P = max(N, M) + szp = ntuple(i -> i <= N ? sz[i] : 1, P) + cntp = ntuple(i -> i <= M ? Int(counts[i]) : 1, P) + + # Nothing to repeat and no new dimensions: identity. + (all(==(1), cntp) && P == N) && return tile + + data_first = mode === :outer + interleaved = ntuple(2P) do j + pair, lo = divrem(j - 1, 2) + (lo == 0) == data_first ? szp[pair + 1] : 1 + end + target = ntuple(2P) do j + pair, lo = divrem(j - 1, 2) + (lo == 0) == data_first ? szp[pair + 1] : cntp[pair + 1] + end + final = ntuple(i -> szp[i] * cntp[i], P) + + t = reshape(tile, interleaved) + t = broadcast_to(t, target) + return reshape(t, final) +end + """ permutedims(tile::Tile{T, S}, perm) -> Tile{T, permuted_shape} diff --git a/test/device/tile.jl b/test/device/tile.jl index 36628011..dac47764 100644 --- a/test/device/tile.jl +++ b/test/device/tile.jl @@ -473,6 +473,75 @@ end end end +@testset "repeat" begin + @testset "1D outer repeat matches Base.repeat" begin + function repeat_1d_kernel(x::ct.TileArray{Float32,1}, y::ct.TileArray{Float32,1}) + bid = ct.bid(1) + tile = ct.load(x, bid, (8,)) + ct.store(y, bid, repeat(tile, 4)) # (8,) -> (32,) + return + end + + x = CuArray(Float32.(1:8)) + y = CUDA.zeros(Float32, 32) + + @cuda backend=cuTile repeat_1d_kernel(x, y) + + @test Array(y) ≈ repeat(Float32.(1:8), 4) + end + + @testset "2D outer repeat matches Base.repeat" begin + function repeat_2d_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2}) + bid = ct.bid(1) + tile = ct.load(x, (bid, 1), (2, 4)) + ct.store(y, (bid, 1), repeat(tile, 2, 2)) # (2, 4) -> (4, 8) + return + end + + src = Float32.(reshape(1:8, 2, 4)) + x = CuArray(src) + y = CUDA.zeros(Float32, 4, 8) + + @cuda backend=cuTile repeat_2d_kernel(x, y) + + @test Array(y) ≈ repeat(src, 2, 2) + end + + @testset "2D inner repeat matches Base.repeat" begin + function repeat_inner_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2}) + bid = ct.bid(1) + tile = ct.load(x, (bid, 1), (2, 4)) + ct.store(y, (bid, 1), repeat(tile; inner=(2, 1))) # (2, 4) -> (4, 4) + return + end + + src = Float32.(reshape(1:8, 2, 4)) + x = CuArray(src) + y = CUDA.zeros(Float32, 4, 4) + + @cuda backend=cuTile repeat_inner_kernel(x, y) + + @test Array(y) ≈ repeat(src; inner=(2, 1)) + end + + @testset "2D combined inner+outer repeat matches Base.repeat" begin + function repeat_inner_outer_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2}) + bid = ct.bid(1) + tile = ct.load(x, (bid, 1), (2, 4)) + ct.store(y, (bid, 1), repeat(tile; inner=(2, 1), outer=(1, 2))) # (2, 4) -> (4, 8) + return + end + + src = Float32.(reshape(1:8, 2, 4)) + x = CuArray(src) + y = CUDA.zeros(Float32, 4, 8) + + @cuda backend=cuTile repeat_inner_outer_kernel(x, y) + + @test Array(y) ≈ repeat(src; inner=(2, 1), outer=(1, 2)) + end +end + @testset "permutedims" begin @testset "2D permutedims (transpose-like)" begin function permute_2d_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2})