Skip to content

Commit bf6489b

Browse files
FFT: 100x speedup by cutting the BS (#232)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 3fe82c6 commit bf6489b

3 files changed

Lines changed: 43 additions & 55 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ Benchmarks comparing cuTile.jl against cuTile Python on an RTX 5080 (`tileiras`
104104
| Layer Norm bwd | 4096² f32 | 246 GB/s | 251 GB/s | OK (-2%) |
105105
| Matrix Multiplication | 4096³ f32 | 47.4 TFLOPS | 43.5 TFLOPS | +9% |
106106
| Batch Matrix Multiply | 1024×512×2048 ×8 f32 | 34.2 TFLOPS | 30.9 TFLOPS | +11% |
107-
| FFT (3-stage Cooley-Tukey) | 512-pt ×64 c64 | 545 μs | 550 μs | OK (+1%) |
107+
| FFT (3-stage Cooley-Tukey) | 4096-pt ×256 c64 | 209 μs | 204 μs | OK (-2%) |
108108
| Mixture of Experts | 256tok 1024h 32e 2048i f16 | 27.7 TFLOPS | 20.3 TFLOPS | +36% |
109109
| Attention (FMHA) | 8×16×1024² ×64 f16 causal | 102.7 TFLOPS | 63.3 TFLOPS | +62% |
110110
| Softmax (TMA) | 4096² f32 | 838 GB/s | 843 GB/s | OK (-1%) |

examples/fft.jl

Lines changed: 36 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -21,40 +21,29 @@ using FFTW
2121
# Python left-multiply W @ X ↔ Julia right-multiply X * W (batch dims trailing).
2222
# Python ct.permute(x, (0,2,3,1)) ↔ Julia permutedims(x, (3,1,2,4)).
2323
function fft_kernel(
24-
x_packed_in::ct.TileArray{Float32, 3}, # Input (D, N2D, BS)
25-
y_packed_out::ct.TileArray{Float32, 3}, # Output (D, N2D, BS)
24+
x_packed_in::ct.TileArray{Float32, 3}, # Input (D, 2N ÷ D, BS)
25+
y_packed_out::ct.TileArray{Float32, 3}, # Output (D, 2N ÷ D, BS)
2626
W0::ct.TileArray{Float32, 3}, # W0 (2, F0, F0) DFT matrix
2727
W1::ct.TileArray{Float32, 3}, # W1 (2, F1, F1)
2828
W2::ct.TileArray{Float32, 3}, # W2 (2, F2, F2)
2929
T0::ct.TileArray{Float32, 3}, # T0 (2, F1F2, F0) twiddle factors
3030
T1::ct.TileArray{Float32, 3}, # T1 (2, F2, F1) twiddle factors
31-
n_const::Int,
32-
f0_const::Int,
33-
f1_const::Int,
34-
f2_const::Int,
35-
f0f1_const::Int,
36-
f1f2_const::Int,
37-
f0f2_const::Int,
38-
bs_const::Int,
39-
d_const::Int,
40-
n2d_const::Int
31+
N::Int,
32+
F0::Int,
33+
F1::Int,
34+
F2::Int,
35+
BS::Int,
36+
D::Int,
4137
)
42-
N = n_const
43-
F0 = f0_const
44-
F1 = f1_const
45-
F2 = f2_const
46-
F0F1 = f0f1_const
47-
F1F2 = f1f2_const
48-
F0F2 = f0f2_const
49-
BS = bs_const
50-
D = d_const
51-
N2D = n2d_const
38+
F0F1 = F0 * F1
39+
F1F2 = F1 * F2
40+
F0F2 = F0 * F2
5241

5342
bid = ct.bid(1)
5443

5544
# --- Load Input Data ---
56-
# Input is (D, N2D, BS). Load and reshape to (2, N, BS).
57-
X_ri = reshape(ct.load(x_packed_in; index=(Int32(1), Int32(1), bid), shape=(D, N2D, BS)), (2, N, BS))
45+
# Input is (D, 2N ÷ D, BS). Load and reshape to (2, N, BS).
46+
X_ri = reshape(ct.load(x_packed_in; index=(Int32(1), Int32(1), bid), shape=(D, 2N ÷ D, BS)), (2, N, BS))
5847

5948
# Split real and imaginary parts, reshape to 4D factored form
6049
X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, N, BS)), (F2, F1, F0, BS))
@@ -131,7 +120,7 @@ function fft_kernel(
131120
# --- Concatenate and Store ---
132121
X_r_final = reshape(X_r10, (1, N, BS))
133122
X_i_final = reshape(X_i10, (1, N, BS))
134-
Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, N2D, BS))
123+
Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, 2N ÷ D, BS))
135124
ct.store(y_packed_out; index=(Int32(1), Int32(1), bid), tile=Y_ri)
136125

137126
return
@@ -193,14 +182,14 @@ end
193182
=============================================================================#
194183

195184
function prepare(; benchmark::Bool=false,
196-
batch::Int=benchmark ? 64 : 2,
197-
factors::NTuple{3,Int}=benchmark ? (8, 8, 8) : (2, 2, 2),
185+
batch::Int=benchmark ? 256 : 2,
186+
factors::NTuple{3,Int}=benchmark ? (16, 16, 16) : (2, 2, 2),
198187
atom_packing_dim::Int=min(64, 2 * prod(factors)))
199-
n = prod(factors)
200-
@assert (n * 2) % atom_packing_dim == 0 "N*2 must be divisible by atom_packing_dim"
188+
N = prod(factors)
189+
@assert 2N % atom_packing_dim == 0 "2 * N must be divisible by atom_packing_dim"
201190

202191
cuRAND.seed!(42)
203-
input = cuRAND.randn(ComplexF32, n, batch)
192+
input = cuRAND.randn(ComplexF32, N, batch)
204193

205194
W0, W1, W2, T0, T1 = make_twiddles(factors)
206195
W0_gpu = CuArray(W0)
@@ -210,46 +199,43 @@ function prepare(; benchmark::Bool=false,
210199
T1_gpu = CuArray(T1)
211200

212201
D = atom_packing_dim
213-
N2D = n * 2 ÷ D
214-
# Pack complex input as (D, N2D, batch) Float32 — matches Python's (batch, N2D, D) row-major.
215-
# When D=2, reinterpret gives (2, n, batch) directly. For D>2, reshape the flat layout.
216-
x_ri = reinterpret(reshape, Float32, input) # (2, n, batch)
217-
x_packed = D == 2 ? x_ri : reshape(x_ri, D, N2D, batch)
218-
y_packed = CuArray{Float32}(undef, D, N2D, batch)
202+
# Pack complex input as (D, 2N ÷ D, batch) Float32 — matches Python's (batch, 2N ÷ D, D) row-major.
203+
# When D=2, reinterpret gives (2, N, batch) directly. For D>2, reshape the flat layout.
204+
x_ri = reinterpret(reshape, Float32, input) # (2, N, batch)
205+
x_packed = D == 2 ? x_ri : reshape(x_ri, D, 2N ÷ D, batch)
206+
y_packed = CuArray{Float32}(undef, D, 2N ÷ D, batch)
219207

220208
return (;
221209
input, x_packed, y_packed,
222210
W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu,
223-
factors, batch, n, D, N2D
211+
factors, batch, N, D
224212
)
225213
end
226214

227215
function run(data; nruns::Int=1, warmup::Int=0)
228216
(; x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu,
229-
factors, batch, n, D, N2D) = data
217+
factors, batch, N, D) = data
230218

231219
F0, F1, F2 = factors
232-
F0F1 = F0 * F1
233-
F1F2 = F1 * F2
234-
F0F2 = F0 * F2
235-
grid = (batch, 1, 1)
220+
BS = 1
221+
grid = (batch ÷ BS, 1, 1)
236222

237223
CUDACore.@sync for _ in 1:warmup
238-
@cuda backend=cuTile blocks=grid fft_kernel(x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct.Constant(n), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), ct.Constant(batch), ct.Constant(D), ct.Constant(N2D))
224+
@cuda backend=cuTile blocks=grid fft_kernel(x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct.Constant(N), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), ct.Constant(BS), ct.Constant(D))
239225
end
240226

241227
times = Float64[]
242228
NVTX.@range "cuTile" begin
243229
for i in 1:nruns
244230
NVTX.@range "run $i" begin
245-
t = CUDACore.@elapsed @cuda backend=cuTile blocks=grid fft_kernel(x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct.Constant(n), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), ct.Constant(batch), ct.Constant(D), ct.Constant(N2D))
231+
t = CUDACore.@elapsed @cuda backend=cuTile blocks=grid fft_kernel(x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct.Constant(N), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), ct.Constant(BS), ct.Constant(D))
246232
push!(times, t * 1000) # ms
247233
end
248234
end
249235
end
250236

251-
# Unpack output: (D, N2D, batch) → (2, n, batch) → ComplexF32(n, batch)
252-
y_ri = D == 2 ? y_packed : reshape(y_packed, 2, n, batch)
237+
# Unpack output: (D, 2n ÷ D, batch) → (2, N, batch) → ComplexF32(n, batch)
238+
y_ri = D == 2 ? y_packed : reshape(y_packed, 2, N, batch)
253239
y_complex = reinterpret(reshape, ComplexF32, y_ri)
254240
output = copy(y_complex)
255241

@@ -272,18 +258,19 @@ end
272258
=============================================================================#
273259

274260
function run_others(data; nruns::Int=1, warmup::Int=0)
275-
(; input, batch, n) = data
261+
(; input, batch, N) = data
276262
results = Dict{String, Vector{Float64}}()
277263

264+
plan = cuFFT.plan_fft!(input, 1)
278265
CUDACore.@sync for _ in 1:warmup
279-
cuFFT.fft!(copy(input), 1)
266+
plan * copy(input)
280267
end
281268
times_cufft = Float64[]
282269
NVTX.@range "cuFFT" begin
283270
for i in 1:nruns
284271
NVTX.@range "run $i" begin
285272
input_copy = copy(input)
286-
t = CUDACore.@elapsed cuFFT.fft!(input_copy, 1)
273+
t = CUDACore.@elapsed plan * input_copy
287274
push!(times_cufft, t * 1000)
288275
end
289276
end

examples/fft.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,9 @@ def fft_make_twiddles(factors, precision, device):
115115
def prepare(*, benchmark: bool = False, batch: int = None, factors: tuple = None, atom_packing_dim: int = None):
116116
"""Allocate and initialize data for FFT."""
117117
if batch is None:
118-
batch = 64 if benchmark else 2
118+
batch = 256 if benchmark else 2
119119
if factors is None:
120-
factors = (8, 8, 8) if benchmark else (2, 2, 2)
120+
factors = (16, 16, 16) if benchmark else (2, 2, 2)
121121
F0, F1, F2 = factors
122122
N = F0 * F1 * F2
123123
D = min(64, N * 2) if atom_packing_dim is None else atom_packing_dim
@@ -152,12 +152,13 @@ def run(data, *, nruns: int = 1, warmup: int = 0):
152152
F0, F1, F2 = data["factors"]
153153
batch, N, D = data["batch"], data["N"], data["D"]
154154

155-
grid = (batch, 1, 1)
155+
BS = 1
156+
grid = (batch // BS, 1, 1)
156157

157158
# Warmup
158159
for _ in range(warmup):
159160
ct.launch(torch.cuda.current_stream(), grid, fft_kernel,
160-
(x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, batch, D))
161+
(x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, BS, D))
161162
torch.cuda.synchronize()
162163

163164
# Timed runs
@@ -169,7 +170,7 @@ def run(data, *, nruns: int = 1, warmup: int = 0):
169170
end = torch.cuda.Event(enable_timing=True)
170171
start.record()
171172
ct.launch(torch.cuda.current_stream(), grid, fft_kernel,
172-
(x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, batch, D))
173+
(x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, BS, D))
173174
end.record()
174175
torch.cuda.synchronize()
175176
times.append(start.elapsed_time(end)) # ms

0 commit comments

Comments
 (0)