Skip to content

Commit 9c3e09f

Browse files
authored
Merge pull request #228 from JuliaGPU/tb/examples
Make examples slightly more idiomatic and aligned
2 parents 785d73d + 973783c commit 9c3e09f

6 files changed

Lines changed: 55 additions & 119 deletions

File tree

README.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,17 @@ Benchmarks comparing cuTile.jl against cuTile Python on an RTX 5080 (`tileiras`
9898

9999
| Kernel | Size | Julia | Python | Status |
100100
|--------|------|-------|--------|--------|
101-
| Vector Addition | 2^27 f32 | 845 GB/s | 846 GB/s | OK (=) |
101+
| Vector Addition | 2^27 f32 | 844 GB/s | 845 GB/s | OK (=) |
102102
| Matrix Transpose | 8192² f32 | 812 GB/s | 814 GB/s | OK (=) |
103-
| Layer Norm fwd | 4096² f32 | 983 GB/s | 716 GB/s | +37% |
104-
| Layer Norm bwd | 4096² f32 | 248 GB/s | 251 GB/s | OK (-1%) |
105-
| Matrix Multiplication | 4096³ f32 | 47.5 TFLOPS | 43.5 TFLOPS | +9% |
106-
| Batch Matrix Multiply | 1024×512×2048 ×8 f32 | 34.0 TFLOPS | 30.8 TFLOPS | +10% |
107-
| FFT (3-stage Cooley-Tukey) | 512-pt ×64 c64 | 529 μs | 554 μs | +5% |
108-
| Mixture of Experts | 256tok 1024h 32e 2048i f16 | 27.0 TFLOPS | 20.1 TFLOPS | +34% |
109-
| Attention (FMHA) | 8×16×1024² ×64 f16 causal | 103.6 TFLOPS | 63.4 TFLOPS | +63% |
110-
| Softmax (TMA) | 4096² f32 | 849 GB/s | 857 GB/s | OK (-1%) |
111-
| Softmax (Chunked) | 4096² f32 | 1684 GB/s | 1640 GB/s | OK (+3%) |
103+
| Layer Norm fwd | 4096² f32 | 986 GB/s | 716 GB/s | +38% |
104+
| Layer Norm bwd | 4096² f32 | 246 GB/s | 251 GB/s | OK (-2%) |
105+
| Matrix Multiplication | 4096³ f32 | 47.4 TFLOPS | 43.5 TFLOPS | +9% |
106+
| Batch Matrix Multiply | 1024×512×2048 ×8 f32 | 34.2 TFLOPS | 30.9 TFLOPS | +11% |
107+
| FFT (3-stage Cooley-Tukey) | 512-pt ×64 c64 | 545 μs | 550 μs | OK (+1%) |
108+
| Mixture of Experts | 256tok 1024h 32e 2048i f16 | 27.7 TFLOPS | 20.3 TFLOPS | +36% |
109+
| Attention (FMHA) | 8×16×1024² ×64 f16 causal | 102.7 TFLOPS | 63.3 TFLOPS | +62% |
110+
| Softmax (TMA) | 4096² f32 | 838 GB/s | 843 GB/s | OK (-1%) |
111+
| Softmax (Chunked) | 4096² f32 | 1672 GB/s | 1636 GB/s | OK (+2%) |
112112

113113

114114
## Supported Operations

examples/fmha.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@ function fmha_kernel(Q::ct.TileArray{T, 4}, K::ct.TileArray{T, 4},
2828
CAUSAL::Bool, EVEN_K::Bool) where {T}
2929
ct.@compiler_options occupancy=2
3030

31-
# Map block IDs to batch and head indices
32-
# Julia: bid(1) = x (seq tiles), bid(2) = y (batch * heads)
31+
# Map block IDs to batch and head indices.
32+
# Julia: bid(1) = x (seq tiles), bid(2) = y (batch * heads).
33+
# `cld`/`mod1` give 1-indexed batch/head/kv-head straight from the
34+
# 1-indexed bid_y, with no 0-indexed detour.
3335
bid_x = ct.bid(1)
34-
bid_y = ct.bid(2) - Int32(1) # 0-indexed for div/mod arithmetic
35-
batch_idx = fld(bid_y, Int32(H)) + Int32(1)
36-
head_idx = rem(bid_y, Int32(H)) + Int32(1)
37-
off_kv_h = fld(head_idx - Int32(1), Int32(QUERY_GROUP_SIZE)) + Int32(1)
36+
bid_y = ct.bid(2)
37+
batch_idx = cld(bid_y, Int32(H))
38+
head_idx = mod1(bid_y, Int32(H))
39+
off_kv_h = cld(head_idx, Int32(QUERY_GROUP_SIZE))
3840

3941
# Adjust qk_scale for exp2
4042
qk_scale = qk_scale * INV_LOG_2

examples/layernorm.jl

Lines changed: 6 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -71,20 +71,14 @@ end
7171
#=============================================================================
7272
LayerNorm Backward Kernels
7373
74-
Backward pass: computes gradients for LayerNorm.
75-
The full backward pass has two kernels:
76-
1. layer_norm_bwd_dx - Computes dX (gradient with respect to input)
77-
2. layer_norm_bwd_dwdb - Computes dW and dB (requires atomic accumulation)
78-
79-
For now, we implement a simplified backward that just computes dX.
74+
Backward pass: computes gradients for LayerNorm via two kernels:
75+
1. layer_norm_bwd_dx_partial_dwdb - dX + per-group partial dW/dB (atomic).
76+
2. layer_norm_bwd_dwdb - final reduction of partial dW/dB.
8077
=============================================================================#
8178

82-
"""
83-
Helper function for backward pass - loads data and computes common terms.
84-
This gets inlined by Julia's compiler.
85-
bid_m and j are 1-indexed (block ID and tile index).
86-
"""
87-
@inline function bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
79+
# Helper function for backward pass - loads data and computes common terms.
80+
# bid_m and j are 1-indexed (block ID and tile index).
81+
function bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
8882
tx = ct.load(X; index=(j, bid_m), shape=(TILE_N, 1), padding_mode=ct.PaddingMode.Zero)
8983
tw = reshape(ct.load(W; index=j, shape=(TILE_N,), padding_mode=ct.PaddingMode.Zero), (TILE_N, 1))
9084
tdy = ct.load(DY; index=(j, bid_m), shape=(TILE_N, 1), padding_mode=ct.PaddingMode.Zero)
@@ -103,53 +97,6 @@ bid_m and j are 1-indexed (block ID and tile index).
10397
return tdy, xhat_masked, wdy_masked
10498
end
10599

106-
"""
107-
layer_norm_bwd_dx(DX, DY, X, W, Mean, Rstd, TILE_N)
108-
109-
Backward pass: computes gradient with respect to input X.
110-
111-
Args:
112-
DX: Output gradient with respect to X (N, M).
113-
DY: Input gradient with respect to Y (N, M).
114-
X: Input tensor (N, M).
115-
W: Weight tensor (N,).
116-
Mean: Mean tensor (M,).
117-
Rstd: Reciprocal standard deviation tensor (M,).
118-
TILE_N: Tile size along N dimension.
119-
"""
120-
function layer_norm_bwd_dx(DX::ct.TileArray{Float32, 2}, DY::ct.TileArray{Float32, 2},
121-
X::ct.TileArray{Float32, 2}, W::ct.TileArray{Float32, 1},
122-
Mean::ct.TileArray{Float32, 1}, Rstd::ct.TileArray{Float32, 1},
123-
TILE_N::Int)
124-
bid_m = ct.bid(1)
125-
num_tiles = ct.num_tiles(X, 1, (TILE_N, 1))
126-
N = size(X, 1)
127-
128-
# Load mean and rstd for this row
129-
mean = ct.load(Mean; index=bid_m, shape=(1,), padding_mode=ct.PaddingMode.Zero)
130-
rstd = ct.load(Rstd; index=bid_m, shape=(1,), padding_mode=ct.PaddingMode.Zero)
131-
132-
# First pass: compute c1 and c2 reduction terms
133-
c1 = zeros(Float32, (TILE_N, 1))
134-
c2 = zeros(Float32, (TILE_N, 1))
135-
for j in Int32(1):num_tiles
136-
_, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
137-
c1 = c1 .+ (xhat .* wdy)
138-
c2 = c2 .+ wdy
139-
end
140-
c1 = sum(c1; dims=1) / N
141-
c2 = sum(c2; dims=1) / N
142-
143-
# Second pass: compute dX
144-
for j in Int32(1):num_tiles
145-
_, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
146-
tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd
147-
ct.store(DX; index=(j, bid_m), tile=tdx)
148-
end
149-
150-
return
151-
end
152-
153100
"""
154101
layer_norm_bwd_dx_partial_dwdb(DX, DY, DW, DB, X, W, Mean, Rstd, Locks, GROUP_SIZE_M, TILE_N)
155102

examples/matmul.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,19 @@ using LinearAlgebra
88
using cuTile: cuTile
99
import cuTile as ct
1010

11-
# 2D swizzle for better L2 cache locality
12-
# Groups blocks to access nearby memory regions together
13-
@inline function swizzle_2d(M, N, tm, tn, GROUP_SIZE_M, bid)
11+
# 2D swizzle for better L2 cache locality. Takes a 1-indexed bid and
12+
# returns 1-indexed (bid_m, bid_n). Modular arithmetic is done on the
13+
# 0-indexed bid internally; the conversion is contained in this helper.
14+
function swizzle_2d(M, N, tm, tn, GROUP_SIZE_M, bid)
1415
num_bid_m = cld(M, Int32(tm))
1516
num_bid_n = cld(N, Int32(tn))
1617
num_bid_in_group = Int32(GROUP_SIZE_M) * num_bid_n
17-
group_id = fld(bid, num_bid_in_group)
18+
bid0 = bid - Int32(1)
19+
group_id = fld(bid0, num_bid_in_group)
1820
first_bid_m = group_id * Int32(GROUP_SIZE_M)
1921
group_size_m = min(num_bid_m - first_bid_m, Int32(GROUP_SIZE_M))
20-
bid_m = first_bid_m + rem(bid, group_size_m)
21-
bid_n = fld(rem(bid, num_bid_in_group), group_size_m)
22+
bid_m = first_bid_m + rem(bid0, group_size_m) + Int32(1)
23+
bid_n = fld(rem(bid0, num_bid_in_group), group_size_m) + Int32(1)
2224
return bid_m, bid_n
2325
end
2426

@@ -31,11 +33,7 @@ function matmul_kernel(A::ct.TileArray{T,2}, B::ct.TileArray{T,2}, C::ct.TileArr
3133
bid = ct.bid(1)
3234
M = size(A, 1)
3335
N = size(B, 2)
34-
# swizzle_2d expects 0-indexed bid, returns 0-indexed tile coords
35-
bid_m_0, bid_n_0 = swizzle_2d(M, N, tm, tn, 8, bid - Int32(1))
36-
# Convert to 1-indexed tile coordinates
37-
bid_m = bid_m_0 + Int32(1)
38-
bid_n = bid_n_0 + Int32(1)
36+
bid_m, bid_n = swizzle_2d(M, N, tm, tn, 8, bid)
3937

4038
# Number of K tiles to iterate over
4139
num_k = ct.num_tiles(A, 2, (tm, tk))

examples/moe.jl

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@ import cuTile as ct
1818
Helper: 2D swizzle (same pattern as matmul.jl)
1919
=============================================================================#
2020

21-
@inline function swizzle_2d(M, N, tm, tn, GROUP_SIZE_M, bid)
21+
function swizzle_2d(M, N, tm, tn, GROUP_SIZE_M, bid)
2222
num_bid_m = cld(M, Int32(tm))
2323
num_bid_n = cld(N, Int32(tn))
2424
num_bid_in_group = Int32(GROUP_SIZE_M) * num_bid_n
25-
group_id = fld(bid, num_bid_in_group)
25+
bid0 = bid - Int32(1)
26+
group_id = fld(bid0, num_bid_in_group)
2627
first_bid_m = group_id * Int32(GROUP_SIZE_M)
2728
group_size_m = min(num_bid_m - first_bid_m, Int32(GROUP_SIZE_M))
28-
bid_m = first_bid_m + rem(bid, group_size_m)
29-
bid_n = fld(rem(bid, num_bid_in_group), group_size_m)
29+
bid_m = first_bid_m + rem(bid0, group_size_m) + Int32(1)
30+
bid_n = fld(rem(bid0, num_bid_in_group), group_size_m) + Int32(1)
3031
return bid_m, bid_n
3132
end
3233

@@ -53,19 +54,18 @@ function fused_moe_kernel(A::ct.TileArray{T, 2}, B::ct.TileArray{T, 3},
5354
K = size(B, 1)
5455
N = size(B, 2)
5556

56-
bid = ct.bid(1) - Int32(1) # 0-indexed for swizzle
57+
bid = ct.bid(1)
5758
bid_m, bid_n = swizzle_2d(M, N, TILE_M, TILE_N, Int32(8), bid)
5859

5960
# Gather 1-indexed token IDs for this block
60-
token_id_indices = bid_m * Int32(TILE_M) .+ ct.arange(TILE_M)
61+
token_id_indices = (bid_m - Int32(1)) * Int32(TILE_M) .+ ct.arange(TILE_M)
6162
token_ids = ct.gather(sorted_token_ids, token_id_indices)
6263

63-
# Map 1-indexed flat token_id to 1-indexed column in A
64-
# token_id k → original token = (k-1) ÷ num_token_replicas + 1
65-
a_tok_indices = (token_ids .- Int32(1)) .÷ Int32(num_token_replicas) .+ Int32(1)
64+
# 1-indexed flat token_id → 1-indexed column in A. Each original token
65+
# has `num_token_replicas` consecutive ids; ceil-divide recovers it.
66+
a_tok_indices = cld.(token_ids, Int32(num_token_replicas))
6667

67-
# Expert for this block (scalar, 1-indexed tile index for load)
68-
expert_id = sorted_expert_ids[bid_m + Int32(1)]
68+
expert_id = sorted_expert_ids[bid_m]
6969

7070
acc = zeros(Float32, TILE_N, TILE_M)
7171
num_k = cld(K, Int32(TILE_K))
@@ -81,7 +81,7 @@ function fused_moe_kernel(A::ct.TileArray{T, 2}, B::ct.TileArray{T, 3},
8181
# B is (K, N, num_experts): load (TILE_N, TILE_K) slice for this expert
8282
# order=(2,1,3) folds the transpose into the load via dim_map, matching
8383
# Python cuTile's order=(0,2,1) and avoiding an explicit permute in Tile IR.
84-
b = ct.load(B; index=(bid_n + Int32(1), k, expert_id),
84+
b = ct.load(B; index=(bid_n, k, expert_id),
8585
shape=(TILE_N, TILE_K, 1), order=(2, 1, 3),
8686
padding_mode=ct.PaddingMode.Zero)
8787
b = reshape(b, (TILE_N, TILE_K))
@@ -97,7 +97,7 @@ function fused_moe_kernel(A::ct.TileArray{T, 2}, B::ct.TileArray{T, 3},
9797

9898
# Scatter result to C at token_id positions
9999
# C is (N, total_tokens): dim 1 = N, dim 2 = tokens
100-
c_n_indices = bid_n * Int32(TILE_N) .+ ct.arange(TILE_N) # 1-indexed
100+
c_n_indices = (bid_n - Int32(1)) * Int32(TILE_N) .+ ct.arange(TILE_N) # 1-indexed
101101
output = convert(ct.Tile{T}, acc)
102102
ct.scatter(C, (reshape(c_n_indices, (TILE_N, 1)),
103103
reshape(token_ids, (1, TILE_M))), output)
@@ -204,24 +204,12 @@ function invoke_silu_and_mul_kernel(AB, C)
204204
inter = size(C, 1) # C is (intermediate, total_tokens)
205205
total = size(AB, 2)
206206

207-
# Split AB(inter*2, total) into gate and up halves along dim 1.
208-
#A_half = AB[1:inter, :]
209-
#B_half = AB[inter+1:2*inter, :]
210-
# FIXME: CUDA.jl's CuArray indexing (AB[1:inter, :]) uses a slow generic kernel.
211-
# Use unsafe_copy2d! (cuMemcpy2D) for hardware-accelerated pitched 2D copy instead.
212-
T = eltype(AB)
213-
A_half = similar(AB, inter, total)
214-
B_half = similar(AB, inter, total)
215-
src_pitch = size(AB, 1) * sizeof(T)
216-
dst_pitch = inter * sizeof(T)
217-
CUDACore.unsafe_copy2d!(pointer(A_half), CUDACore.DeviceMemory,
218-
pointer(AB), CUDACore.DeviceMemory,
219-
inter, total; srcPitch=src_pitch, dstPitch=dst_pitch,
220-
async=true)
221-
CUDACore.unsafe_copy2d!(pointer(B_half), CUDACore.DeviceMemory,
222-
pointer(AB) + inter * sizeof(T), CUDACore.DeviceMemory,
223-
inter, total; srcPitch=src_pitch, dstPitch=dst_pitch,
224-
async=true)
207+
# Split AB(inter*2, total) into gate and up halves along dim 1 — mirrors
208+
# cuTile Python's `AB.chunk(2, dim=-1)`. Views are non-contiguous along
209+
# dim 2 but each block only loads a (TILE_N, 1) tile, so codegen is
210+
# unaffected.
211+
A_half = view(AB, 1:inter, :)
212+
B_half = view(AB, (inter + 1):(2 * inter), :)
225213

226214
tile_n = nextpow(2, inter)
227215
@cuda backend=cuTile blocks=total silu_and_mul_kernel(A_half, B_half, C, ct.Constant(tile_n))

src/launch.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,13 @@ struct KernelAdaptor end
2727
Adapt.adapt_storage(::KernelAdaptor, arr::AbstractArray) = TileArray(arr)
2828
Adapt.adapt_storage(::KernelAdaptor, t::Type) = Constant(t)
2929

30-
# Adapt's default `adapt_structure(to, ::PermutedDimsArray)` recurses by
31-
# rebuilding `PermutedDimsArray(adapt(parent), perm)`. We can't follow that
30+
# Adapt's defaults for `PermutedDimsArray` and `SubArray` recurse by
31+
# rebuilding the wrapper around `adapt(parent)`. We can't follow that
3232
# pattern because `TileArray` isn't `<:AbstractArray` — strided-wrapper
3333
# state is absorbed into its `sizes`/`strides` fields directly. Short-circuit
3434
# the recursion so the whole wrapper becomes a single TileArray.
3535
Adapt.adapt_structure(::KernelAdaptor, arr::PermutedDimsArray) = TileArray(arr)
36+
Adapt.adapt_structure(::KernelAdaptor, arr::SubArray) = TileArray(arr)
3637

3738
"""
3839
cuTileconvert(x)

0 commit comments

Comments
 (0)