Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,13 @@ Tile IR operations.
### Shape
| Operation | Description |
|-----------|-------------|
| `ct.broadcast_to(tile, shape)` | Broadcast to target shape |
| `transpose(tile)` | Transpose 2D tile |
| `reshape(tile, shape)` | Reshape (same element count) |
| `transpose(tile)` | Transpose 2D tile |
| `permutedims(tile, perm)` | Permute dimensions |
| `repeat(tile, counts...)` `repeat(tile; inner, outer)` | Repeat values along dimensions |
| `ct.extract(tile, index, shape)` | Extract sub-tile |
| `ct.cat((a, b), axis)` | Concatenate tiles |
| `ct.broadcast_to(tile, shape)` | Broadcast to target shape |
| `dropdims(tile; dims)` | Remove singleton dimensions |

### Matrix
Expand Down
75 changes: 75 additions & 0 deletions src/language/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,81 @@ reshaped = reshape(tile, (2, 16)) # Shape (2, 16), still 32 elements
end
@inline Base.reshape(tile::Tile{T}, dims::Int...) where {T} = reshape(tile, dims)

"""
repeat(tile::Tile, counts::Integer...) -> Tile

Repeat a tile along each dimension (outer repetition), matching `Base.repeat`.
`counts[i]` is the number of times to tile the whole tile along dimension `i`;
dimensions beyond `length(counts)` are repeated once, and counts beyond
`ndims(tile)` introduce new trailing dimensions.

`counts` must be compile-time constants, and every resulting dimension
`size(tile, i) * counts[i]` (as well as the new repeat dimensions) must be a
power of two.

# Example
```julia
tile = ct.load(arr, (1, 1), (4, 8)) # Shape (4, 8)
tiled = repeat(tile, 2, 2) # Shape (8, 16)
```
"""
Base.repeat(tile::Tile, counts::Integer...) = repeat(tile; outer = counts)

"""
repeat(tile::Tile; inner=nothing, outer=nothing) -> Tile

Keyword form of `repeat`, matching `Base.repeat`. `inner[i]` repeats each element
`inner[i]` times along dimension `i` (inner repetition), while `outer[i]` tiles the
whole tile `outer[i]` times along dimension `i` (outer repetition). When both are
given, the inner repetition is applied first. Each of `inner`/`outer` may be an
`Integer` or a tuple, and must be a compile-time constant.

# Example
```julia
tile = ct.load(arr, (1, 1), (2, 4)) # Shape (2, 4)
repeat(tile; inner=(2, 1), outer=(1, 2)) # Shape (4, 8)
```
"""
function Base.repeat(tile::Tile; inner = nothing, outer = nothing)
t = inner === nothing ? tile :
_repeat(tile, :inner, inner isa Integer ? (inner,) : Tuple(inner))
return outer === nothing ? t :
_repeat(t, :outer, outer isa Integer ? (outer,) : Tuple(outer))
end

# Implements both inner and outer repetition by reshaping the tile to interleave
# a singleton dimension next to each data dimension, broadcasting that singleton
# up to the repeat count, then collapsing each pair back together. For outer
# repetition the data dimension is the fast (inner) one within each pair; for
# inner repetition the repeat count is. Shapes derive from the tile type and the
# (constant) counts, so const-prop folds them away.
function _repeat(tile::Tile, mode::Symbol, counts::Tuple{Vararg{Integer}})
sz = size(tile)
N = ndims(tile)
M = length(counts)
P = max(N, M)
szp = ntuple(i -> i <= N ? sz[i] : 1, P)
cntp = ntuple(i -> i <= M ? Int(counts[i]) : 1, P)

# Nothing to repeat and no new dimensions: identity.
(all(==(1), cntp) && P == N) && return tile

data_first = mode === :outer
interleaved = ntuple(2P) do j
pair, lo = divrem(j - 1, 2)
(lo == 0) == data_first ? szp[pair + 1] : 1
end
target = ntuple(2P) do j
pair, lo = divrem(j - 1, 2)
(lo == 0) == data_first ? szp[pair + 1] : cntp[pair + 1]
end
final = ntuple(i -> szp[i] * cntp[i], P)

t = reshape(tile, interleaved)
t = broadcast_to(t, target)
return reshape(t, final)
end

"""
permutedims(tile::Tile{T, S}, perm) -> Tile{T, permuted_shape}

Expand Down
69 changes: 69 additions & 0 deletions test/device/tile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,75 @@ end
end
end

@testset "repeat" begin
@testset "1D outer repeat matches Base.repeat" begin
function repeat_1d_kernel(x::ct.TileArray{Float32,1}, y::ct.TileArray{Float32,1})
bid = ct.bid(1)
tile = ct.load(x, bid, (8,))
ct.store(y, bid, repeat(tile, 4)) # (8,) -> (32,)
return
end

x = CuArray(Float32.(1:8))
y = CUDA.zeros(Float32, 32)

@cuda backend=cuTile repeat_1d_kernel(x, y)

@test Array(y) ≈ repeat(Float32.(1:8), 4)
end

@testset "2D outer repeat matches Base.repeat" begin
function repeat_2d_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2})
bid = ct.bid(1)
tile = ct.load(x, (bid, 1), (2, 4))
ct.store(y, (bid, 1), repeat(tile, 2, 2)) # (2, 4) -> (4, 8)
return
end

src = Float32.(reshape(1:8, 2, 4))
x = CuArray(src)
y = CUDA.zeros(Float32, 4, 8)

@cuda backend=cuTile repeat_2d_kernel(x, y)

@test Array(y) ≈ repeat(src, 2, 2)
end

@testset "2D inner repeat matches Base.repeat" begin
function repeat_inner_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2})
bid = ct.bid(1)
tile = ct.load(x, (bid, 1), (2, 4))
ct.store(y, (bid, 1), repeat(tile; inner=(2, 1))) # (2, 4) -> (4, 4)
return
end

src = Float32.(reshape(1:8, 2, 4))
x = CuArray(src)
y = CUDA.zeros(Float32, 4, 4)

@cuda backend=cuTile repeat_inner_kernel(x, y)

@test Array(y) ≈ repeat(src; inner=(2, 1))
end

@testset "2D combined inner+outer repeat matches Base.repeat" begin
function repeat_inner_outer_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2})
bid = ct.bid(1)
tile = ct.load(x, (bid, 1), (2, 4))
ct.store(y, (bid, 1), repeat(tile; inner=(2, 1), outer=(1, 2))) # (2, 4) -> (4, 8)
return
end

src = Float32.(reshape(1:8, 2, 4))
x = CuArray(src)
y = CUDA.zeros(Float32, 4, 8)

@cuda backend=cuTile repeat_inner_outer_kernel(x, y)

@test Array(y) ≈ repeat(src; inner=(2, 1), outer=(1, 2))
end
end

@testset "permutedims" begin
@testset "2D permutedims (transpose-like)" begin
function permute_2d_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2})
Expand Down