Auto-generating fused broadcast kernels #121

Closed
AntonOresten wants to merge 4 commits into JuliaGPU:main from AntonOresten:fused-broadcast

@AntonOresten AntonOresten commented Mar 16, 2026

Collaborator

This is more of an issue with supplementary code than a PR, and it does not necessarily belong in cuTile.jl, but I had the idea of procedurally generating cuTile kernels for fused broadcasting, like Julia already does. My main motivation is leveraging tileiras for things like converting FP8 activations to arithmetic types and then writing back to FP8 again, but it also seems to perform reasonably well for more mundane things:

julia> f!(B, A) = ct.@fuse B .= A .+ A # or ct.@. B = A + A
f! (generic function with 1 method)

julia> A = CUDA.rand(1024,1024,1024); B = similar(A);

julia> @be CUDA.@sync B .= A .+ A
Benchmark: 8 samples with 1 evaluation
 min    11.314 ms (313 allocs: 6.234 KiB)
 median 12.129 ms (313 allocs: 6.234 KiB)
 mean   12.704 ms (313 allocs: 6.234 KiB, 5.18% gc time)
 max    14.765 ms (313 allocs: 6.234 KiB, 27.19% gc time)

julia> @be CUDA.@sync f!(B, A)
Benchmark: 10 samples with 1 evaluation
 min    10.495 ms (447 allocs: 11.250 KiB)
 median 10.507 ms (447 allocs: 11.250 KiB)
 mean   10.518 ms (447 allocs: 11.250 KiB, 6.34% gc time)
 max    10.606 ms (447 allocs: 11.250 KiB, 51.37% gc time)

julia> B == A * 2
true

Currently it creates 8 methods to accommodate different destination dimensionalities and dispatches on ndims(dest) at launch time (this could probably be a @generated function instead). The default tile size is simply (64, 64, 1, 1, ...). Singleton broadcasting, type conversions (e.g. BFloat16.(x)), scalars, etc. all work.
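
As a rough illustration of the default tile size and the launch grid it implies, here is a minimal sketch in plain Julia (no GPU required). The names `default_tile_size`, `tile_grid`, and `launch_shape` are hypothetical and are not taken from the PR; the sketch only assumes the (64, 64, 1, 1, ...) default described above and that the destination's dimensionality can be read from its type.

```julia
# Hypothetical default: 64×64 tiles over the first two dimensions, size 1 elsewhere.
default_tile_size(N::Integer) = ntuple(i -> i <= 2 ? 64 : 1, N)

# Number of tiles per dimension; ceiling division accounts for partial edge tiles.
tile_grid(dims::NTuple{N,Int}, tile::NTuple{N,Int}=default_tile_size(N)) where {N} =
    cld.(dims, tile)

# The dimensionality N is part of the array type, so it is available to
# dispatch without a separate hand-written method per ndims(dest).
launch_shape(dest::AbstractArray{T,N}) where {T,N} = tile_grid(size(dest))

default_tile_size(4)                      # (64, 64, 1, 1)
launch_shape(zeros(Float32, 100, 70, 3))  # (2, 2, 3)
```

Under these assumptions a 1024×1024 destination maps to a 16×16 grid of tiles. Whether the generated kernels themselves can be parameterised on N in the same way is a separate question that this sketch does not answer.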

I haven't tested it extensively, but it might be fairly robust. On its own it isn't too useful outside of leveraging tileiras. I wonder if anything like this could be used for epilogues.

I suppose a macro-less version could be done by intercepting Base.Broadcast, maybe with method tables by passing a function to a function. The macro is straightforward enough. Curious to hear your thoughts though, @maleadt.

maleadt commented Mar 17, 2026

Member

We'll definitely need ways to expose cuTile.jl to users rather than kernel developers, so this is a good first step. It's a bit unfortunate that we're not reusing Julia's broadcast fusion here. I wonder if we could have a macro that only redirects the broadcast style to something cuTile.jl-specific and otherwise reuses more of the existing machinery? A separate array type would be easiest, but I'm not sure we want that. Or maybe it would be fine if it's only a wrapper to influence dispatch (here, BroadcastStyle), e.g., CUDA.Tiled(B) .= A .+ A.
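
The wrapper idea can be sketched on the CPU without any GPU packages. The following is a minimal illustration, not cuTile.jl's implementation: `Tiled` and `tiled_copyto!` are hypothetical names, the "kernel" is an ordinary loop over 64×64 tiles, and for brevity it overloads `Base.Broadcast.materialize!` on the wrapper directly instead of defining a custom `BroadcastStyle`. It assumes one-based arrays.

```julia
using Base.Broadcast: Broadcasted, instantiate

# Hypothetical wrapper: it holds no data of its own and exists only so that
# `Tiled(dest) .= ...` dispatches to a custom broadcast implementation.
struct Tiled{A<:AbstractArray}
    parent::A
end

# `Tiled(B) .= f.(args...)` lowers to `materialize!(Tiled(B), broadcasted(f, args...))`.
# By this point Julia has already fused the right-hand side into one lazy tree.
function Base.Broadcast.materialize!(dest::Tiled, bc::Broadcasted{<:Any})
    out = dest.parent
    # Attach the destination axes; this also checks that the shapes are compatible.
    fused = instantiate(Broadcasted(bc.f, bc.args, axes(out)))
    tiled_copyto!(out, fused)
    return dest
end

# Stand-in for a tile kernel launch: visit the output one tile at a time.
function tiled_copyto!(out::AbstractArray{T,N}, bc::Broadcasted,
                       tile::NTuple{N,Int}=ntuple(i -> i <= 2 ? 64 : 1, N)) where {T,N}
    for t in CartesianIndices(cld.(size(out), tile))   # one iteration per tile
        lo = (Tuple(t) .- 1) .* tile .+ 1
        hi = min.(Tuple(t) .* tile, size(out))          # clip partial edge tiles
        for I in CartesianIndices(map(:, lo, hi))
            out[I] = bc[I]                              # evaluates the fused expression at I
        end
    end
    return out
end

A = rand(Float32, 100, 70); B = similar(A)
Tiled(B) .= A .+ A          # runs the tiled loop instead of Base's copyto!
B == 2 .* A                 # true
```

Indexing the `Broadcasted` object handles singleton dimensions and scalars, so an expression such as `Tiled(B) .= A .* v .+ 1` with a length-100 vector `v` goes through the same path. A real implementation would replace the inner loop with a tile load, compute, and store.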

maleadt commented Mar 19, 2026

Member

Okay, I vibe coded something that looks like:

julia> using CUDA, cuTile

julia> import cuTile as ct

julia> ct.Tiled(CUDA.rand(1024,1024)) .= CUDA.rand(1024,1024) .* CUDA.rand(1024,1024)
CUDAExt._Tiled{CuArray{Float32, 2, CUDA.DeviceMemory}}(Float32[0.22044289 0.38459682 … 0.036794282 0.055302374; 0.3504372 0.6617892 … 0.28499892 0.029165484; … ; 0.12322262 0.16766846 … 0.30772275 0.5190717; 0.28138888 0.099202745 … 0.034274556 0.020667627])

julia> @device_code_tiled ct.Tiled(CUDA.rand(1024,1024)) .= CUDA.rand(1024,1024) .* CUDA.rand(1024,1024)
// _tiled_bc_kernel_2d(cuTile.TileArray{Float32, 2, cuTile.ArraySpec{2, 128, true, (0, 4), (32, 32)}()}, Base.Broadcast.Broadcasted{Nothing, Nothing, typeof(*), Tuple{cuTile.TileArray{Float32, 2, cuTile.ArraySpec{2, 128, true, (0, 4), (32, 32)}()}, cuTile.TileArray{Float32, 2, cuTile.ArraySpec{2, 128, true, (0, 4), (32, 32)}()}}}, cuTile.Constant{Tuple{Int64, Int64}, (64, 64)})

cuda_tile.module @kernels {
  entry @_tiled_bc_kernel_2d(%arg0: tile<ptr<f32>>, %arg1: tile<i32>, %arg2: tile<i32>, %arg3: tile<i32>, %arg4: tile<i32>, %arg5: tile<ptr<f32>>, %arg6: tile<i32>, %arg7: tile<i32>, %arg8: tile<i32>, %arg9: tile<i32>, %arg10: tile<ptr<f32>>, %arg11: tile<i32>, %arg12: tile<i32>, %arg13: tile<i32>, %arg14: tile<i32>) {
    %assume = assume div_by<128>, %arg0 : tile<ptr<f32>>
    %assume_0 = assume bounded<0, ?>, %arg1 : tile<i32>
    %assume_1 = assume bounded<0, ?>, %arg2 : tile<i32>
    %assume_2 = assume bounded<0, ?>, %arg4 : tile<i32>
    %assume_assume = assume div_by<32>, %assume_0 : tile<i32>
    %assume_assume_3 = assume div_by<32>, %assume_1 : tile<i32>
    %assume_assume_4 = assume div_by<4>, %assume_2 : tile<i32>
    %tview = make_tensor_view %assume, shape = [%assume_assume, %assume_assume_3], strides = [1, %assume_assume_4] : tile<i32> -> tensor_view<?x?xf32, strides=[1,?]>
    %assume_5 = assume div_by<128>, %arg5 : tile<ptr<f32>>
    %assume_6 = assume bounded<0, ?>, %arg6 : tile<i32>
    %assume_7 = assume bounded<0, ?>, %arg7 : tile<i32>
    %assume_8 = assume bounded<0, ?>, %arg9 : tile<i32>
    %assume_assume_9 = assume div_by<32>, %assume_6 : tile<i32>
    %assume_assume_10 = assume div_by<32>, %assume_7 : tile<i32>
    %assume_assume_11 = assume div_by<4>, %assume_8 : tile<i32>
    %tview_12 = make_tensor_view %assume_5, shape = [%assume_assume_9, %assume_assume_10], strides = [1, %assume_assume_11] : tile<i32> -> tensor_view<?x?xf32, strides=[1,?]>
    %assume_13 = assume div_by<128>, %arg10 : tile<ptr<f32>>
    %assume_14 = assume bounded<0, ?>, %arg11 : tile<i32>
    %assume_15 = assume bounded<0, ?>, %arg12 : tile<i32>
    %assume_16 = assume bounded<0, ?>, %arg14 : tile<i32>
    %assume_assume_17 = assume div_by<32>, %assume_14 : tile<i32>
    %assume_assume_18 = assume div_by<32>, %assume_15 : tile<i32>
    %assume_assume_19 = assume div_by<4>, %assume_16 : tile<i32>
    %tview_20 = make_tensor_view %assume_13, shape = [%assume_assume_17, %assume_assume_18], strides = [1, %assume_assume_19] : tile<i32> -> tensor_view<?x?xf32, strides=[1,?]>
    %0 = make_token : token
    %blockId_x, %blockId_y, %blockId_z = get_tile_block_id : tile<i32>
    %cst_1_i32 = constant <i32: 1> : tile<i32>
    %1 = addi %blockId_x, %cst_1_i32 : tile<i32>
    %blockId_x_21, %blockId_y_22, %blockId_z_23 = get_tile_block_id : tile<i32>
    %cst_1_i32_24 = constant <i32: 1> : tile<i32>
    %2 = addi %blockId_y_22, %cst_1_i32_24 : tile<i32>
    %false = constant <i1: false> : tile<i1>
    %pview = make_partition_view %tview_12 : partition_view<tile=(64x64), tensor_view<?x?xf32, strides=[1,?]>>
    %cst_1_i32_25 = constant <i32: 1> : tile<i32>
    %3 = subi %1, %cst_1_i32_25 : tile<i32>
    %cst_1_i32_26 = constant <i32: 1> : tile<i32>
    %4 = subi %2, %cst_1_i32_26 : tile<i32>
    %tile, %result_token = load_view_tko weak %pview[%3, %4] token = %0 : partition_view<tile=(64x64), tensor_view<?x?xf32, strides=[1,?]>>, tile<i32> -> tile<64x64xf32>, token
    %pview_27 = make_partition_view %tview_20 : partition_view<tile=(64x64), tensor_view<?x?xf32, strides=[1,?]>>
    %cst_1_i32_28 = constant <i32: 1> : tile<i32>
    %5 = subi %1, %cst_1_i32_28 : tile<i32>
    %cst_1_i32_29 = constant <i32: 1> : tile<i32>
    %6 = subi %2, %cst_1_i32_29 : tile<i32>
    %tile_30, %result_token_31 = load_view_tko weak %pview_27[%5, %6] token = %result_token : partition_view<tile=(64x64), tensor_view<?x?xf32, strides=[1,?]>>, tile<i32> -> tile<64x64xf32>, token
    %7 = mulf %tile, %tile_30  : tile<64x64xf32>
    %cst_1_i32_32 = constant <i32: 1> : tile<i32>
    %8 = subi %1, %cst_1_i32_32 : tile<i32>
    %cst_1_i32_33 = constant <i32: 1> : tile<i32>
    %9 = subi %2, %cst_1_i32_33 : tile<i32>
    %pview_34 = make_partition_view %tview : partition_view<tile=(64x64), tensor_view<?x?xf32, strides=[1,?]>>
    %10 = store_view_tko weak %7, %pview_34[%8, %9] token = %result_token_31 : tile<64x64xf32>, partition_view<tile=(64x64), tensor_view<?x?xf32, strides=[1,?]>>, tile<i32> -> token
    return
  }
}

Pretty nice, no? And reuses much more of the broadcast machinery, without the need to generate kernels.

maleadt commented Mar 19, 2026

Member

Interestingly, cuTile is already faster than our regular broadcast on the device (4.05 µs vs. 10.73 µs of GPU time in the traces below):

julia> A = CUDA.rand(1024, 1024);

julia> B = CUDA.rand(1024, 1024);

julia> C = CUDA.rand(1024, 1024);

julia> CUDA.@profile trace=true A .=+ B .* C
Profiler ran for 91.31 µs, capturing 67 events.

Host-side activity: calling CUDA APIs took 31.95 µs (34.99% of the trace)
┌────┬──────────┬──────────┬────────────────┐
│ ID │    Start │     Time │ Name           │
├────┼──────────┼──────────┼────────────────┤
│ 64 │ 57.22 µs │ 25.27 µs │ cuLaunchKernel │
└────┴──────────┴──────────┴────────────────┘

Device-side activity: GPU was busy for 10.73 µs (11.75% of the trace)
┌────┬──────────┬──────────┬─────────┬────────┬──────┬─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
│ ID │    Start │     Time │ Threads │ Blocks │ Regs │ Name                                                                                                                                                           ⋯
├────┼──────────┼──────────┼─────────┼────────┼──────┼─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
│ 64 │ 79.87 µs │ 10.73 µs │     768 │   2048 │   28 │ gpu_broadcast_kernel_cartesian(CompilerMetadata<DynamicSize, DynamicCheck, void, CartesianIndices<2, Tuple<OneTo<Int64>, OneTo<Int64>>>, NDRange<2, DynamicSiz ⋯
└────┴──────────┴──────────┴─────────┴────────┴──────┴─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                                                                                                                                                                                       1 column omitted


julia> CUDA.@profile trace=true ct.Tiled(A) .=+ B .* C
Profiler ran for 142.57 µs, capturing 25 events.

Host-side activity: calling CUDA APIs took 31.47 µs (22.07% of the trace)
┌────┬───────────┬──────────┬────────────────┐
│ ID │     Start │     Time │ Name           │
├────┼───────────┼──────────┼────────────────┤
│ 22 │ 105.86 µs │ 29.09 µs │ cuLaunchKernel │
└────┴───────────┴──────────┴────────────────┘

Device-side activity: GPU was busy for 4.05 µs (2.84% of the trace)
┌────┬───────────┬─────────┬─────────┬────────┬──────┬─────────────────────┐
│ ID │     Start │    Time │ Threads │ Blocks │ Regs │ Name                │
├────┼───────────┼─────────┼─────────┼────────┼──────┼─────────────────────┤
│ 22 │ 137.33 µs │ 4.05 µs │     128 │  16×16 │   84 │ _tiled_bc_kernel_2d │
└────┴───────────┴─────────┴─────────┴────────┴──────┴─────────────────────┘

Maybe not too surprising though, since we do some forced specialization on broadcast arguments. I wonder if we should consider backporting the ct.Const logic to CUDA.jl's SIMT kernels.

maleadt commented Mar 20, 2026

Member

I moved my code into #129, since it was entirely disconnected from the work on this branch.

@AntonOresten

Collaborator Author

All good! Looks super sleek, and is going in the right direction.

@AntonOresten

Collaborator Author

Superseded by #129

@AntonOresten AntonOresten deleted the fused-broadcast branch March 20, 2026 17:51