Skip to content

Commit c4b4757

Browse files
maleadt and claude committed
Generalize tiled broadcast kernel to N dimensions
Replace separate 1D/2D broadcast kernels with a single @generated kernel that handles arbitrary dimensionality, matching the @fuse macro's bid-construction pattern for N>3 grid delinearization.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 828c5e3 commit c4b4757

1 file changed

Lines changed: 35 additions & 25 deletions

File tree

ext/CUDAExt.jl

Lines changed: 35 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -350,21 +350,34 @@ function _to_tiled_bc(bc::Broadcasted)
350350
end
351351

352352
# The generic broadcast kernel: evaluates the Broadcasted tree on tiles
353-
# Removed by this commit ("-" gutter): 1-D kernel. The single cuTile.bid(1) value
# selects both the tile evaluated by _eval_bc and the tile stored into dest.
function _tiled_bc_kernel_1d(dest::TileArray{T, 1}, bc, tile_size) where T
354-
bid = cuTile.bid(1)
355-
result = _eval_bc(bc, bid, tile_size)
356-
# Convert the evaluated tile to dest's element type T before storing.
result_converted = convert(cuTile.Tile{T}, result)
357-
cuTile.store(dest, bid, result_converted)
358-
return
359-
end
353+
# Added by this commit ("+" gutter): one generated kernel for any dest rank N.
# The code below runs at generation time: it collects statements in `body` and
# returns them as a single block, so the per-dimension bid setup is written out
# explicitly for the concrete N instead of being looped over inside the kernel.
#
# Arguments (as used by the generated statements):
#   dest           - TileArray{T, N} that receives the result tile
#   bc             - broadcast tree handed unchanged to _eval_bc
#   tile_size      - handed unchanged to _eval_bc
#   overflow_grids - only indexed when N > 3; entry d-2 is used as the extent of
#                    dimension d when splitting cuTile.bid(3) (the launcher in
#                    this diff passes grid[3:end])
@generated function _tiled_bc_kernel(dest::TileArray{T, N}, bc, tile_size, overflow_grids) where {T, N}
354+
body = Expr[]
355+
# One local variable name per dimension: bid_1, bid_2, ..., bid_N.
bid_vars = [Symbol("bid_$d") for d in 1:N]
360356

361-
# Removed by this commit ("-" gutter): 2-D kernel, replaced by the generated
# kernel whose added lines are interleaved with it here.
function _tiled_bc_kernel_2d(dest::TileArray{T, 2}, bc, tile_size) where T
362-
bid_x = cuTile.bid(1)
363-
bid_y = cuTile.bid(2)
364-
result = _eval_bc(bc, (bid_x, bid_y), tile_size)
365-
result_converted = convert(cuTile.Tile{T}, result)
366-
cuTile.store(dest, (bid_x, bid_y), result_converted)
367-
return
357+
# N <= 3: every dimension d reads its id directly from cuTile.bid(d).
if N <= 3
358+
for d in 1:N
359+
push!(body, :($(bid_vars[d]) = cuTile.bid($d)))
360+
end
361+
else
362+
# N > 3: dimensions 1 and 2 still read cuTile.bid(1) / cuTile.bid(2);
# dimensions 3..N are all recovered from the single cuTile.bid(3) value.
push!(body, :($(bid_vars[1]) = cuTile.bid(1)))
363+
push!(body, :($(bid_vars[2]) = cuTile.bid(2)))
364+
# Subtract 1 so the rem/fld steps below work on a zero-based value.
push!(body, :(_rem = cuTile.bid(3) - Int32(1)))
365+
for d in 3:N
366+
if d < N
367+
# Peel off dimension d: the remainder (plus 1) is its id, the quotient
# carries on to the next dimension.
push!(body, :($(bid_vars[d]) = rem(_rem, Int32(overflow_grids[$(d-2)])) + Int32(1)))
368+
push!(body, :(_rem = fld(_rem, Int32(overflow_grids[$(d-2)]))))
369+
else
370+
# Last dimension: whatever quotient is left, shifted back by 1.
push!(body, :($(bid_vars[d]) = _rem + Int32(1)))
371+
end
372+
end
373+
end
374+
375+
# Rank 1 indexes with the bare bid variable; higher ranks use a tuple of them.
idx = N == 1 ? bid_vars[1] : Expr(:tuple, bid_vars...)
376+
push!(body, :(result = _eval_bc(bc, $idx, tile_size)))
377+
# T is spliced into the generated code as the concrete element type of dest.
push!(body, :(result_converted = convert(cuTile.Tile{$T}, result)))
378+
push!(body, :(cuTile.store(dest, $idx, result_converted)))
379+
push!(body, :(return))
380+
# The returned block is the kernel body for this (T, N).
Expr(:block, body...)
368381
end
369382

370383
# Recursive tree evaluation inside kernel
@@ -391,17 +404,14 @@ function _tiled_broadcast!(dest::CuArray{T,N}, bc::Broadcasted; tile_size::Int=6
391404
dest_ta = TileArray(dest)
392405
tiled_bc = _to_tiled_bc(bc)
393406

394-
if N == 1
395-
ts = (tile_size,)
396-
grid = (cld(size(dest, 1), tile_size),)
397-
cuTile.launch(_tiled_bc_kernel_1d, grid, dest_ta, tiled_bc, Constant(ts))
398-
elseif N == 2
399-
ts = (tile_size, tile_size)
400-
grid = (cld(size(dest, 1), tile_size), cld(size(dest, 2), tile_size))
401-
cuTile.launch(_tiled_bc_kernel_2d, grid, dest_ta, tiled_bc, Constant(ts))
402-
else
403-
error("Tiled broadcast not yet supported for $N dimensions")
404-
end
407+
ts = ntuple(i -> i <= min(N, 2) ? tile_size : 1, N)
408+
grid = ntuple(i -> cld(size(dest, i), ts[i]), N)
409+
410+
launch_grid = N <= 3 ? grid : (grid[1], grid[2], prod(grid[i] for i in 3:N))
411+
overflow = N > 3 ? grid[3:end] : ()
412+
413+
cuTile.launch(_tiled_bc_kernel, launch_grid, dest_ta, tiled_bc,
414+
Constant(ts), Constant(overflow))
405415
end
406416

407417
end

0 commit comments

Comments
 (0)