Skip to content

Commit 3b1300a

Browse files
maleadt and claude committed
Add tiled broadcast via Base.Broadcast integration
Implements `ct.Tiled(B) .= A .+ A .* B` syntax that leverages Julia's broadcast fusion machinery and dispatches to cuTile kernels.

- `Tiled` wrapper type with `TiledCuArrayStyle` that wins over `CuArrayStyle` and `DefaultArrayStyle`
- `materialize!` converts the fused `Broadcasted` tree: CuArrays become TileArrays, style/axes are stripped
- Generic 1D/2D kernels recursively evaluate the `Broadcasted` tree on tiles, using `broadcast(bc.f, args...)` for element-wise semantics
- Supports arbitrarily nested fused expressions (e.g. `A .+ A .* B`)

Type-constructor broadcasts (e.g. `BFloat16.(A)`) are not yet supported due to `Type{T}` fields causing compilation issues.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e33c254 commit 3b1300a

2 files changed

Lines changed: 124 additions & 2 deletions

File tree

ext/CUDAExt.jl

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ using CompilerCaching: CacheView, method_instance, results
99

1010
import Core.Compiler as CC
1111

12-
using CUDA: CuModule, CuFunction, cudacall, device, capability
12+
using CUDA: CuArray, CuModule, CuFunction, cudacall, device, capability
1313
using CUDA_Compiler_jll
1414

15+
import Base.Broadcast: BroadcastStyle, Broadcasted, DefaultArrayStyle
16+
import CUDA: CuArrayStyle
17+
1518
public launch
1619

1720
function run_and_collect(cmd)
@@ -285,4 +288,120 @@ Other values pass through unchanged.
285288
to_tile_arg(x) = x
286289
to_tile_arg(arr::AbstractArray) = TileArray(arr)
287290

291+
#=============================================================================
292+
Tiled Broadcast via Base.Broadcast
293+
=============================================================================#
294+
295+
"""
296+
    _Tiled{A <: AbstractArray}

Wrapper that routes broadcast expressions through cuTile kernels. Instances are
created with `cuTile.Tiled(arr)`:

    Tiled(B) .= A .+ A

Uses Julia's `Base.Broadcast` fusion machinery to build a `Broadcasted` tree,
then dispatches to a generic cuTile kernel that evaluates the tree on tiles.
304+
"""
305+
struct _Tiled{A<:AbstractArray}
    # The wrapped array; every query below is forwarded to it.
    parent::A
end

# Unwrap to the underlying array.
Base.parent(t::_Tiled) = t.parent

# Shape queries are forwarded to the wrapped array.
Base.size(t::_Tiled) = size(parent(t))
Base.size(t::_Tiled, d) = size(parent(t), d)
Base.axes(t::_Tiled) = axes(parent(t))
Base.axes(t::_Tiled, d) = axes(parent(t), d)
Base.length(t::_Tiled) = length(parent(t))

# `ndims` and `eltype` depend only on the wrapped array *type*, so they are defined in
# the type domain (the form generic code queries via `eltype(typeof(x))`), with the
# instance methods delegating to them. Without the type-domain method,
# `eltype(typeof(t))` would fall back to `Any`.
Base.ndims(::Type{_Tiled{A}}) where {A} = ndims(A)
Base.ndims(t::_Tiled) = ndims(typeof(t))
Base.eltype(::Type{_Tiled{A}}) where {A} = eltype(A)
Base.eltype(t::_Tiled) = eltype(typeof(t))

# `similar` keeps the result wrapped; `setindex!` writes through to the wrapped array.
Base.similar(t::_Tiled, args...) = _Tiled(similar(parent(t), args...))
Base.setindex!(t::_Tiled, v, i...) = setindex!(parent(t), v, i...)

# Let a wrapper appear as an operand of a dot expression (e.g. `Tiled(B) .= Tiled(A) .+ 1`).
# `_Tiled` is not an `AbstractArray`, so Base's default `broadcastable` would try to
# `collect` it, which fails because the wrapper is not iterable.
Base.Broadcast.broadcastable(t::_Tiled) = t
318+
319+
# Public entry point: `cuTile.Tiled(arr)` wraps `arr` in the extension-private `_Tiled`.
function cuTile.Tiled(arr::AbstractArray)
    return _Tiled(arr)
end
320+
321+
# Broadcast style tag for expressions that involve a `_Tiled`-wrapped `CuArray`.
# The parameter `N` carries the number of dimensions.
struct TiledCuArrayStyle{N} <: BroadcastStyle end

# Re-dimensioning constructor: the existing parameter is ignored and the style for the
# dimensionality carried by the `Val` is returned.
function TiledCuArrayStyle{M}(::Val{N}) where {N,M}
    return TiledCuArrayStyle{N}()
end

# A `_Tiled` around an `N`-dimensional `CuArray` broadcasts with the tiled style.
function BroadcastStyle(::Type{<:_Tiled{<:CuArray{T,N}}}) where {T,N}
    return TiledCuArrayStyle{N}()
end

# Binary rules: the tiled style wins over `CuArrayStyle`, `DefaultArrayStyle`, and
# itself; the result takes the larger of the two dimensionalities.
function BroadcastStyle(::TiledCuArrayStyle{P}, ::CuArrayStyle{Q}) where {P,Q}
    return TiledCuArrayStyle{max(P, Q)}()
end

function BroadcastStyle(::TiledCuArrayStyle{P}, ::DefaultArrayStyle{Q}) where {P,Q}
    return TiledCuArrayStyle{max(P, Q)}()
end

function BroadcastStyle(::TiledCuArrayStyle{P}, ::TiledCuArrayStyle{Q}) where {P,Q}
    return TiledCuArrayStyle{max(P, Q)}()
end
330+
331+
# materialize! dispatch: `Tiled(B) .= expr`
#
# Validates that `expr` can be broadcast into `dest`, hands the fused expression to the
# tiled kernel launcher, and returns `dest` (the wrapper, as passed in).
#
# Throws `DimensionMismatch` when an operand's shape is incompatible with `dest`.
function Base.Broadcast.materialize!(dest::_Tiled, bc::Broadcasted)
    # Mirror the check Base's own `materialize!` performs: re-anchor the expression on
    # the destination axes and instantiate it, which raises `DimensionMismatch` if any
    # operand cannot be broadcast to `dest`. The instantiated object is discarded; only
    # the validation is wanted here, and the original `bc` is forwarded unchanged.
    # NOTE(review): operands that pass this check through singleton-dimension expansion
    # are still loaded by the kernels with the destination's tile index — confirm the
    # kernels handle that case.
    Base.Broadcast.instantiate(Broadcasted{Nothing}(bc.f, bc.args, axes(dest)))
    _tiled_broadcast!(parent(dest), bc)
    return dest
end
336+
337+
"""
338+
_to_tiled_bc(bc)
339+
340+
Walk a Broadcasted tree, converting leaf CuArrays to TileArrays and stripping
341+
style/axes (replacing with nothing). Scalars and other leaves pass through.
342+
"""
343+
function _to_tiled_bc(arr::CuArray)
    return TileArray(arr)
end

# A `_Tiled` wrapper is unwrapped first, then converted like a bare array.
function _to_tiled_bc(t::_Tiled)
    return TileArray(parent(t))
end

# Numbers, and any other leaf type, are passed through untouched.
_to_tiled_bc(x::Number) = x
_to_tiled_bc(x) = x

# Interior node: convert every argument recursively and rebuild the node with the same
# function but with `Nothing` as style and `nothing` as axes.
function _to_tiled_bc(bc::Broadcasted)
    converted = map(_to_tiled_bc, bc.args)
    return Broadcasted{Nothing}(bc.f, converted, nothing)
end
351+
352+
# Generic 1D broadcast kernel: evaluates the `Broadcasted` tree `bc` at the index
# returned by `cuTile.bid(1)`, converts the result to `cuTile.Tile{T}`, and stores it
# into `dest` at that same index.
function _tiled_bc_kernel_1d(dest::TileArray{T,1}, bc, tile_size) where {T}
    idx = cuTile.bid(1)
    tile = convert(cuTile.Tile{T}, _eval_bc(bc, idx, tile_size))
    cuTile.store(dest, idx, tile)
    return nothing
end
360+
361+
# Generic 2D broadcast kernel: same as the 1D kernel, but the index is the pair
# `(cuTile.bid(1), cuTile.bid(2))`, used both to evaluate `bc` and to store the result.
function _tiled_bc_kernel_2d(dest::TileArray{T,2}, bc, tile_size) where {T}
    idx = (cuTile.bid(1), cuTile.bid(2))
    tile = convert(cuTile.Tile{T}, _eval_bc(bc, idx, tile_size))
    cuTile.store(dest, idx, tile)
    return nothing
end
369+
370+
# Recursive evaluation of the converted `Broadcasted` tree inside the kernel.

# Leaf: a `TileArray` contributes whatever `cuTile.load(arr, bid, tile_size)` returns.
@inline function _eval_bc(arr::TileArray, bid, tile_size)
    return cuTile.load(arr, bid, tile_size)
end

# Leaf: numbers are used as-is.
@inline function _eval_bc(x::Number, bid, tile_size)
    return x
end

# Interior node: evaluate all arguments, then apply the node's function to them.
@inline function _eval_bc(bc::Broadcasted, bid, tile_size)
    operands = _eval_bc_args(bc.args, bid, tile_size)
    # Go through `broadcast` for element-wise semantics; calling `bc.f` directly would
    # dispatch to e.g. matmul for `*` on tiles.
    return broadcast(bc.f, operands...)
end

# Evaluate an argument tuple by peeling off the head and recursing on the tail; the
# empty tuple ends the recursion.
@inline _eval_bc_args(::Tuple{}, bid, tile_size) = ()
@inline function _eval_bc_args(args::Tuple, bid, tile_size)
    head = _eval_bc(args[1], bid, tile_size)
    rest = _eval_bc_args(Base.tail(args), bid, tile_size)
    return (head, rest...)
end
384+
385+
"""
386+
_tiled_broadcast!(dest, bc; tile_size=64)
387+
388+
Launch a tiled broadcast kernel for the fused expression `bc` writing to `dest`.
389+
"""
390+
function _tiled_broadcast!(dest::CuArray{T,N}, bc::Broadcasted; tile_size::Int=64) where {T,N}
    # Arguments:
    #   dest      - destination array; only 1- and 2-dimensional arrays are supported.
    #   bc        - fused broadcast expression; converted with `_to_tiled_bc` before launch.
    #   tile_size - tile edge length, used for every dimension (so 2D tiles are square).
    # Throws `ArgumentError` for a non-positive `tile_size` and an `ErrorException` for
    # an unsupported number of dimensions.

    # Reject bad tile sizes up front: `cld(n, 0)` would surface as a bare `DivideError`
    # and a negative size would produce a meaningless grid.
    tile_size > 0 ||
        throw(ArgumentError("tile_size must be positive, got $tile_size"))

    if N != 1 && N != 2
        error("Tiled broadcast not yet supported for $N dimensions")
    end

    # Nothing to compute for an empty destination. This also avoids requesting a launch
    # with a zero-sized grid (presumably invalid — confirm against `cuTile.launch`).
    isempty(dest) && return nothing

    dest_ta = TileArray(dest)
    tiled_bc = _to_tiled_bc(bc)

    # Same tile edge in every dimension; one grid cell per (possibly partial) tile.
    ts = ntuple(_ -> tile_size, Val(N))
    grid = ntuple(d -> cld(size(dest, d), tile_size), Val(N))

    # `N` is a static parameter, so this branch is resolved at compile time.
    if N == 1
        return cuTile.launch(_tiled_bc_kernel_1d, grid, dest_ta, tiled_bc, Constant(ts))
    else
        return cuTile.launch(_tiled_bc_kernel_2d, grid, dest_ta, tiled_bc, Constant(ts))
    end
end
406+
288407
end

src/cuTile.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ include("language/operations.jl")
3939
include("language/atomics.jl")
4040
include("language/broadcast_macro.jl")
4141

42-
public launch, ByTarget, @compiler_options
42+
public launch, Tiled, ByTarget, @compiler_options
4343
# Catch-all method: any call that reaches it raises an error asking the user to import
# CUDA.jl. (The working methods are presumably added by the CUDA extension — confirm.)
launch(args...) = error("Please import CUDA.jl before using `cuTile.launch`.")
4444

45+
# Tiled(arr) is defined in CUDAExt; provide a function stub for the public API
46+
"""
    Tiled(arr)

Wrap `arr` so that a broadcast assignment such as `Tiled(B) .= A .+ A` is routed
through cuTile kernels.

This declaration only reserves the public name; the methods are defined in the CUDA
extension (`CUDAExt`).
"""
function Tiled end
47+
4548
end # module cuTile

0 commit comments

Comments
 (0)