@@ -9,9 +9,12 @@ using CompilerCaching: CacheView, method_instance, results
99
1010import Core. Compiler as CC
1111
12- using CUDA: CuModule, CuFunction, cudacall, device, capability
12+ using CUDA: CuArray, CuModule, CuFunction, cudacall, device, capability
1313using CUDA_Compiler_jll
1414
15+ import Base. Broadcast: BroadcastStyle, Broadcasted, DefaultArrayStyle
16+ import CUDA: CuArrayStyle
17+
1518public launch
1619
1720function run_and_collect (cmd)
@@ -285,4 +288,120 @@ Other values pass through unchanged.
function to_tile_arg(x)
    # Anything that is not an array is forwarded untouched.
    return x
end

function to_tile_arg(arr::AbstractArray)
    # Arrays are wrapped in a TileArray.
    return TileArray(arr)
end
287290
291+ #= ============================================================================
292+ Tiled Broadcast via Base.Broadcast
293+ =============================================================================#
294+
"""
    _Tiled{A<:AbstractArray}

Wrapper that routes broadcast expressions through cuTile kernels. Instances are
created with `cuTile.Tiled(arr)`.

    Tiled(B) .= A .+ A

Uses Julia's `Base.Broadcast` fusion machinery to build a `Broadcasted` tree,
then dispatches to a generic cuTile kernel that evaluates the tree on tiles.

The wrapper stores the wrapped array in its `parent` field and forwards `size`,
`axes`, `ndims`, `eltype`, `length`, `similar` and `setindex!` to it. It is its own
`broadcastable` representation, so it can also appear as an operand of a broadcast
expression.
"""
struct _Tiled{A<:AbstractArray}
    parent::A
end

Base.parent(t::_Tiled) = t.parent

# Shape queries are forwarded to the wrapped array.
Base.size(t::_Tiled) = size(parent(t))
Base.size(t::_Tiled, d) = size(parent(t), d)
Base.axes(t::_Tiled) = axes(parent(t))
Base.axes(t::_Tiled, d) = axes(parent(t), d)
Base.length(t::_Tiled) = length(parent(t))

# `ndims`/`eltype` are answered from the type parameter. The type-level methods
# make `eltype(typeof(t))` and `ndims(typeof(t))` agree with the instance-level
# answers instead of falling back to Base's generic defaults.
Base.ndims(::Type{_Tiled{A}}) where {A<:AbstractArray} = ndims(A)
Base.ndims(::_Tiled{A}) where {A<:AbstractArray} = ndims(A)
Base.eltype(::Type{_Tiled{A}}) where {A<:AbstractArray} = eltype(A)
Base.eltype(::_Tiled{A}) where {A<:AbstractArray} = eltype(A)

Base.similar(t::_Tiled, args...) = _Tiled(similar(parent(t), args...))
Base.setindex!(t::_Tiled, v, i...) = setindex!(parent(t), v, i...)

# `_Tiled` is not an `AbstractArray`, so without this method Base's fallback
# `broadcastable(x) = collect(x)` would try to iterate the wrapper and fail
# whenever it is used as a broadcast operand.
Base.Broadcast.broadcastable(t::_Tiled) = t

cuTile.Tiled(arr::AbstractArray) = _Tiled(arr)
320+
# Broadcast style carried by expressions that involve a `_Tiled`-wrapped CuArray.
# The parameter `N` is the dimensionality.
struct TiledCuArrayStyle{N} <: BroadcastStyle end

# Re-parameterise the style for a requested dimensionality `Val(D)`; the
# parameter of the style being called is ignored.
function TiledCuArrayStyle{Ignored}(::Val{D}) where {D,Ignored}
    return TiledCuArrayStyle{D}()
end

# A `_Tiled` wrapper around an `N`-dimensional CuArray broadcasts with the tiled style.
function BroadcastStyle(::Type{<:_Tiled{<:CuArray{T,N}}}) where {T,N}
    return TiledCuArrayStyle{N}()
end

# TiledCuArrayStyle wins over CuArrayStyle and DefaultArrayStyle; combining two
# styles keeps the larger dimensionality.
function BroadcastStyle(::TiledCuArrayStyle{P}, ::CuArrayStyle{Q}) where {P,Q}
    return TiledCuArrayStyle{max(P, Q)}()
end

function BroadcastStyle(::TiledCuArrayStyle{P}, ::DefaultArrayStyle{Q}) where {P,Q}
    return TiledCuArrayStyle{max(P, Q)}()
end

function BroadcastStyle(::TiledCuArrayStyle{P}, ::TiledCuArrayStyle{Q}) where {P,Q}
    return TiledCuArrayStyle{max(P, Q)}()
end
330+
# materialize! dispatch: Tiled(B) .= expr
#
# Writes the fused expression `bc` into the array wrapped by `dest` via
# `_tiled_broadcast!` and returns `dest` (the wrapper, not the parent).
function Base.Broadcast.materialize!(dest::_Tiled, bc::Broadcasted)
    # This method replaces Base's generic `materialize!`, which instantiates the
    # expression against `axes(dest)` before copying. Run the same check here so
    # shape-incompatible operands raise `DimensionMismatch` up front instead of
    # reaching the kernel launch. Only the check matters; the result is unused.
    Base.Broadcast.instantiate(Broadcasted{Nothing}(bc.f, bc.args, axes(dest)))
    _tiled_broadcast!(parent(dest), bc)
    return dest
end
336+
"""
    _to_tiled_bc(bc)

Rewrite a `Broadcasted` tree for use inside a tiled kernel.

Leaf `CuArray`s (and the arrays wrapped by `_Tiled`) become `TileArray`s. Every
`Broadcasted` node is rebuilt with style `Nothing` and `nothing` axes, keeping its
function and its recursively converted arguments. Numbers and any other leaf are
returned as they are.
"""
function _to_tiled_bc(bc::Broadcasted)
    return Broadcasted{Nothing}(bc.f, map(_to_tiled_bc, bc.args), nothing)
end

function _to_tiled_bc(arr::CuArray)
    return TileArray(arr)
end

function _to_tiled_bc(t::_Tiled)
    return TileArray(parent(t))
end

function _to_tiled_bc(x::Number)
    return x
end

# Fallback for other types: pass through unchanged.
function _to_tiled_bc(x)
    return x
end
351+
# Generic 1-D broadcast kernel: the block with index `cuTile.bid(1)` evaluates
# the broadcast tree `bc` for its tile, converts the result to the destination
# element type `T`, and stores it into `dest` at the same block index.
function _tiled_bc_kernel_1d(dest::TileArray{T,1}, bc, tile_size) where {T}
    block = cuTile.bid(1)
    tile = convert(cuTile.Tile{T}, _eval_bc(bc, block, tile_size))
    cuTile.store(dest, block, tile)
    return nothing
end
360+
# Generic 2-D broadcast kernel: same scheme as the 1-D kernel, with the block
# index being the pair `(cuTile.bid(1), cuTile.bid(2))`.
function _tiled_bc_kernel_2d(dest::TileArray{T,2}, bc, tile_size) where {T}
    block = (cuTile.bid(1), cuTile.bid(2))
    tile = convert(cuTile.Tile{T}, _eval_bc(bc, block, tile_size))
    cuTile.store(dest, block, tile)
    return nothing
end
369+
# Recursive evaluation of the converted broadcast tree inside the kernel.
#
# Leaves: a `TileArray` is loaded at block index `bid` with shape `tile_size`;
# a number evaluates to itself.
@inline function _eval_bc(arr::TileArray, bid, tile_size)
    return cuTile.load(arr, bid, tile_size)
end

@inline function _eval_bc(x::Number, bid, tile_size)
    return x
end

# Interior node: evaluate every argument, then apply the node's function.
@inline function _eval_bc(bc::Broadcasted, bid, tile_size)
    operands = _eval_bc_args(bc.args, bid, tile_size)
    # Go through `broadcast` for element-wise semantics; calling `bc.f` directly
    # would dispatch to e.g. matmul for `*` on tiles.
    return broadcast(bc.f, operands...)
end

# Evaluate an argument tuple head-first, recursing on the tail until it is empty.
@inline function _eval_bc_args(::Tuple{}, bid, tile_size)
    return ()
end

@inline function _eval_bc_args(args::Tuple, bid, tile_size)
    head = _eval_bc(args[1], bid, tile_size)
    rest = _eval_bc_args(Base.tail(args), bid, tile_size)
    return (head, rest...)
end
384+
"""
    _tiled_broadcast!(dest, bc; tile_size=64)

Launch a tiled broadcast kernel for the fused expression `bc` writing to `dest`.

# Arguments
- `dest::CuArray{T,N}`: destination array. Only `N == 1` and `N == 2` are supported.
- `bc::Broadcasted`: fused broadcast tree; it is converted with `_to_tiled_bc`
  before the launch.

# Keywords
- `tile_size::Int=64`: tile edge length, used for every dimension. Must be positive.

# Returns
Whatever `cuTile.launch` returns, or `nothing` when `dest` is empty (nothing is
launched in that case).

# Throws
- `ErrorException` if `N` is neither 1 nor 2.
- `ArgumentError` if `tile_size` is not positive.
"""
function _tiled_broadcast!(dest::CuArray{T,N}, bc::Broadcasted; tile_size::Int=64) where {T,N}
    if N != 1 && N != 2
        error(" Tiled broadcast not yet supported for $N dimensions")
    end
    # A zero tile size would make `cld` throw `DivideError`, and a negative one
    # would produce a meaningless grid; reject both with a clear message.
    tile_size > 0 ||
        throw(ArgumentError("tile_size must be positive, got $tile_size"))
    # An empty destination would give a zero-sized grid; there is nothing to
    # compute, so skip the launch entirely.
    isempty(dest) && return nothing

    dest_ta = TileArray(dest)
    tiled_bc = _to_tiled_bc(bc)

    # One block per tile; `cld` rounds up so a trailing partial tile gets a block.
    if N == 1
        ts = (tile_size,)
        grid = (cld(size(dest, 1), tile_size),)
        return cuTile.launch(_tiled_bc_kernel_1d, grid, dest_ta, tiled_bc, Constant(ts))
    else
        ts = (tile_size, tile_size)
        grid = (cld(size(dest, 1), tile_size), cld(size(dest, 2), tile_size))
        return cuTile.launch(_tiled_bc_kernel_2d, grid, dest_ta, tiled_bc, Constant(ts))
    end
end
406+
288407end
0 commit comments