From 5bf99fd78cb8c5bed84433d25264c47e11c61523 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Sun, 31 May 2026 10:54:34 +0200 Subject: [PATCH 1/3] Add `Base.repeat` methods for `Tile` --- src/language/operations.jl | 74 ++++++++++++++++++++++++++++++++++++++ test/device/tile.jl | 69 +++++++++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+) diff --git a/src/language/operations.jl b/src/language/operations.jl index 6328ce36..0831c4ab 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -786,6 +786,80 @@ reshaped = reshape(tile, (2, 16)) # Shape (2, 16), still 32 elements end @inline Base.reshape(tile::Tile{T}, dims::Int...) where {T} = reshape(tile, dims) +""" + repeat(tile::Tile, counts::Integer...) -> Tile + +Repeat a tile along each dimension (outer repetition), matching `Base.repeat`. +`counts[i]` is the number of times to tile the whole tile along dimension `i`; +dimensions beyond `length(counts)` are repeated once, and counts beyond +`ndims(tile)` introduce new trailing dimensions. + +`counts` must be compile-time constants, and every resulting dimension +`size(tile, i) * counts[i]` (as well as the new repeat dimensions) must be a +power of two. + +# Example +```julia +tile = ct.load(arr, (1, 1), (4, 8)) # Shape (4, 8) +tiled = repeat(tile, 2, 2) # Shape (8, 16) +``` +""" +@inline Base.repeat(tile::Tile, counts::Integer...) = + _repeat(tile, Val(:outer), Val(counts)) + +""" + repeat(tile::Tile; inner=nothing, outer=nothing) -> Tile + +Keyword form of `repeat`, matching `Base.repeat`. `inner[i]` repeats each element +`inner[i]` times along dimension `i` (inner repetition), while `outer[i]` tiles the +whole tile `outer[i]` times along dimension `i` (outer repetition). When both are +given, the inner repetition is applied first. Each of `inner`/`outer` may be an +`Integer` or a tuple, and must be a compile-time constant. + +# Example +```julia +tile = ct.load(arr, (1, 1), (2, 4)) # Shape (2, 4) +repeat(tile; inner=(2, 1), outer=(1, 2)) # Shape (4, 8) +``` +""" +@inline function Base.repeat(tile::Tile; inner = nothing, outer = nothing) + t = inner === nothing ? tile : + _repeat(tile, Val(:inner), Val(inner isa Integer ? (inner,) : Tuple(inner))) + return outer === nothing ? t : + _repeat(t, Val(:outer), Val(outer isa Integer ? (outer,) : Tuple(outer))) +end + +@generated function _repeat(tile::Tile, ::Val{Mode}, ::Val{Counts}) where {Mode, Counts} + sz = size(tile) + N = ndims(tile) + counts = Int.(Counts) + M = length(counts) + P = max(N, M) + szp = ntuple(i -> i <= N ? sz[i] : 1, P) + cntp = ntuple(i -> i <= M ? counts[i] : 1, P) + final = ntuple(i -> szp[i] * cntp[i], P) + + # Nothing to repeat and no new dimensions: identity. + all(==(1), cntp) && P == N && return :(tile) + + data_first = Mode === :outer + is_data(lo) = (lo == 0) == data_first + interleaved = ntuple(2P) do j + pair, lo = divrem(j - 1, 2) + is_data(lo) ? szp[pair + 1] : 1 + end + target = ntuple(2P) do j + pair, lo = divrem(j - 1, 2) + is_data(lo) ? szp[pair + 1] : cntp[pair + 1] + end + + return quote + t = reshape(tile, $interleaved) + t = broadcast_to(t, $target) + reshape(t, $final) + end +end + """ permutedims(tile::Tile{T, S}, perm) -> Tile{T, permuted_shape} diff --git a/test/device/tile.jl b/test/device/tile.jl index 36628011..dac47764 100644 --- a/test/device/tile.jl +++ b/test/device/tile.jl @@ -473,6 +473,75 @@ end end end +@testset "repeat" begin + @testset "1D outer repeat matches Base.repeat" begin + function repeat_1d_kernel(x::ct.TileArray{Float32,1}, y::ct.TileArray{Float32,1}) + bid = ct.bid(1) + tile = ct.load(x, bid, (8,)) + ct.store(y, bid, repeat(tile, 4)) # (8,) -> (32,) + return + end + + x = CuArray(Float32.(1:8)) + y = CUDA.zeros(Float32, 32) + + @cuda backend=cuTile repeat_1d_kernel(x, y) + + @test Array(y) ≈ repeat(Float32.(1:8), 4) + end + + @testset "2D outer repeat matches Base.repeat" begin + function repeat_2d_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2}) + bid = ct.bid(1) + tile = ct.load(x, (bid, 1), (2, 4)) + ct.store(y, (bid, 1), repeat(tile, 2, 2)) # (2, 4) -> (4, 8) + return + end + + src = Float32.(reshape(1:8, 2, 4)) + x = CuArray(src) + y = CUDA.zeros(Float32, 4, 8) + + @cuda backend=cuTile repeat_2d_kernel(x, y) + + @test Array(y) ≈ repeat(src, 2, 2) + end + + @testset "2D inner repeat matches Base.repeat" begin + function repeat_inner_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2}) + bid = ct.bid(1) + tile = ct.load(x, (bid, 1), (2, 4)) + ct.store(y, (bid, 1), repeat(tile; inner=(2, 1))) # (2, 4) -> (4, 4) + return + end + + src = Float32.(reshape(1:8, 2, 4)) + x = CuArray(src) + y = CUDA.zeros(Float32, 4, 4) + + @cuda backend=cuTile repeat_inner_kernel(x, y) + + @test Array(y) ≈ repeat(src; inner=(2, 1)) + end + + @testset "2D combined inner+outer repeat matches Base.repeat" begin + function repeat_inner_outer_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2}) + bid = ct.bid(1) + tile = ct.load(x, (bid, 1), (2, 4)) + ct.store(y, (bid, 1), repeat(tile; inner=(2, 1), outer=(1, 2))) # (2, 4) -> (4, 8) + return + end + + src = Float32.(reshape(1:8, 2, 4)) + x = CuArray(src) + y = CUDA.zeros(Float32, 4, 8) + + @cuda backend=cuTile repeat_inner_outer_kernel(x, y) + + @test Array(y) ≈ repeat(src; inner=(2, 1), outer=(1, 2)) + end +end + @testset "permutedims" begin @testset "2D permutedims (transpose-like)" begin function permute_2d_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2}) From 8bf26c1ed8c1bceabcbee1f9fa713978ebf263c1 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 1 Jun 2026 18:23:50 +0200 Subject: [PATCH 2/3] Avoid generated functions and Val. --- src/language/operations.jl | 39 +++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/src/language/operations.jl b/src/language/operations.jl index 0831c4ab..3f551976 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -804,8 +804,7 @@ tile = ct.load(arr, (1, 1), (4, 8)) # Shape (4, 8) tiled = repeat(tile, 2, 2) # Shape (8, 16) ``` """ -@inline Base.repeat(tile::Tile, counts::Integer...) = - _repeat(tile, Val(:outer), Val(counts)) +Base.repeat(tile::Tile, counts::Integer...) = repeat(tile; outer = counts) """ repeat(tile::Tile; inner=nothing, outer=nothing) -> Tile @@ -822,42 +821,44 @@ tile = ct.load(arr, (1, 1), (2, 4)) # Shape (2, 4) repeat(tile; inner=(2, 1), outer=(1, 2)) # Shape (4, 8) ``` """ -@inline function Base.repeat(tile::Tile; inner = nothing, outer = nothing) +function Base.repeat(tile::Tile; inner = nothing, outer = nothing) t = inner === nothing ? tile : - _repeat(tile, Val(:inner), Val(inner isa Integer ? (inner,) : Tuple(inner))) + _repeat(tile, :inner, inner isa Integer ? (inner,) : Tuple(inner)) return outer === nothing ? t : - _repeat(t, Val(:outer), Val(outer isa Integer ? (outer,) : Tuple(outer))) + _repeat(t, :outer, outer isa Integer ? (outer,) : Tuple(outer)) end -@generated function _repeat(tile::Tile, ::Val{Mode}, ::Val{Counts}) where {Mode, Counts} +# Implements both inner and outer repetition by reshaping the tile to interleave +# a singleton dimension next to each data dimension, broadcasting that singleton +# up to the repeat count, then collapsing each pair back together. For outer +# repetition the data dimension is the fast (inner) one within each pair; for +# inner repetition the repeat count is. Shapes derive from the tile type and the +# (constant) counts, so const-prop folds them away. +function _repeat(tile::Tile, mode::Symbol, counts::Tuple{Vararg{Integer}}) sz = size(tile) N = ndims(tile) - counts = Int.(Counts) M = length(counts) P = max(N, M) szp = ntuple(i -> i <= N ? sz[i] : 1, P) - cntp = ntuple(i -> i <= M ? counts[i] : 1, P) - final = ntuple(i -> szp[i] * cntp[i], P) + cntp = ntuple(i -> i <= M ? Int(counts[i]) : 1, P) # Nothing to repeat and no new dimensions: identity. - all(==(1), cntp) && P == N && return :(tile) + (all(==(1), cntp) && P == N) && return tile - data_first = Mode === :outer - is_data(lo) = (lo == 0) == data_first + data_first = mode === :outer interleaved = ntuple(2P) do j pair, lo = divrem(j - 1, 2) - is_data(lo) ? szp[pair + 1] : 1 + (lo == 0) == data_first ? szp[pair + 1] : 1 end target = ntuple(2P) do j pair, lo = divrem(j - 1, 2) - is_data(lo) ? szp[pair + 1] : cntp[pair + 1] + (lo == 0) == data_first ? szp[pair + 1] : cntp[pair + 1] end + final = ntuple(i -> szp[i] * cntp[i], P) - return quote - t = reshape(tile, $interleaved) - t = broadcast_to(t, $target) - reshape(t, $final) - end + t = reshape(tile, interleaved) + t = broadcast_to(t, target) + return reshape(t, final) end """ From 4fb8793ba164f1d6f62722f363b0170fd84b85f5 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Mon, 1 Jun 2026 18:37:56 +0200 Subject: [PATCH 3/3] Add `repeat` line; update ordering of operations --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a753db39..e53ce14b 100644 --- a/README.md +++ b/README.md @@ -184,12 +184,13 @@ Tile IR operations. ### Shape | Operation | Description | |-----------|-------------| -| `ct.broadcast_to(tile, shape)` | Broadcast to target shape | -| `transpose(tile)` | Transpose 2D tile | | `reshape(tile, shape)` | Reshape (same element count) | +| `transpose(tile)` | Transpose 2D tile | | `permutedims(tile, perm)` | Permute dimensions | +| `repeat(tile, counts...)` `repeat(tile; inner, outer)` | Repeat values along dimensions | | `ct.extract(tile, index, shape)` | Extract sub-tile | | `ct.cat((a, b), axis)` | Concatenate tiles | +| `ct.broadcast_to(tile, shape)` | Broadcast to target shape | | `dropdims(tile; dims)` | Remove singleton dimensions | ### Matrix