Skip to content

Commit dd600c8

Browse files
authored
Align FFT examples. (#145)
1 parent b9a9b91 commit dd600c8

3 files changed

Lines changed: 29 additions & 40 deletions

File tree

README.md

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,12 @@ Benchmarks comparing cuTile.jl against cuTile Python on an RTX 5080:
9696

9797
| Kernel | Julia | Python | Status |
9898
|--------|-------|--------|--------|
99-
| Vector Addition | 840 GB/s | 844 GB/s | OK (=) |
100-
| Matrix Transpose | 806 GB/s | 816 GB/s | OK (-1%) |
101-
| Layer Normalization | 1074 GB/s | 761 GB/s | OK (+41%) |
102-
| Matrix Multiplication | 36.8 TFLOPS | 50.7 TFLOPS | -27% |
103-
| Batch Matrix Multiply | 28.3 TFLOPS | 40.0 TFLOPS | -29% |
104-
| FFT (3-stage Cooley-Tukey) | 571 μs | 192 μs | -66% |
105-
106-
Memory-bound kernels (vadd, transpose, layernorm) match or beat Python. Compute-intensive
107-
kernels (matmul, batch matmul, FFT) are slower due to conservative token threading in the
108-
generated Tile IR, which serializes loads that could otherwise be pipelined.
99+
| Vector Addition | 841 GB/s | 847 GB/s | OK (=) |
100+
| Matrix Transpose | 807 GB/s | 813 GB/s | OK (-1%) |
101+
| Layer Normalization | 653 GB/s | 758 GB/s | -14% |
102+
| Matrix Multiplication | 43.1 TFLOPS | 50.3 TFLOPS | -14% |
103+
| Batch Matrix Multiply | 30.4 TFLOPS | 40.0 TFLOPS | -24% |
104+
| FFT (3-stage Cooley-Tukey) | 620 μs | 486 μs | -28% |
109105

110106

111107
## Supported Operations

examples/fft.jl

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ function fft_kernel(
5252

5353
# --- Load Input Data ---
5454
# Input is (D, N2D, BS). Load and reshape to (2, N, BS).
55-
X_ri = reshape(ct.load(x_packed_in; index=(1, 1, bid), shape=(D, N2D, BS)), (2, N, BS))
55+
X_ri = reshape(ct.load(x_packed_in; index=(Int32(1), Int32(1), bid), shape=(D, N2D, BS)), (2, N, BS))
5656

5757
# Split real and imaginary parts, reshape to 4D factored form
5858
X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, N, BS)), (F2, F1, F0, BS))
@@ -130,7 +130,7 @@ function fft_kernel(
130130
X_r_final = reshape(X_r10, (1, N, BS))
131131
X_i_final = reshape(X_i10, (1, N, BS))
132132
Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, N2D, BS))
133-
ct.store(y_packed_out; index=(1, 1, bid), tile=Y_ri)
133+
ct.store(y_packed_out; index=(Int32(1), Int32(1), bid), tile=Y_ri)
134134

135135
return
136136
end
@@ -192,15 +192,12 @@ end
192192

193193
function prepare(; benchmark::Bool=false,
194194
batch::Int=benchmark ? 64 : 2,
195-
n::Int=benchmark ? 512 : 8,
196195
factors::NTuple{3,Int}=benchmark ? (8, 8, 8) : (2, 2, 2),
197-
atom_packing_dim::Int=2)
198-
@assert factors[1] * factors[2] * factors[3] == n "Factors must multiply to N"
196+
atom_packing_dim::Int=min(64, 2 * prod(factors)))
197+
n = prod(factors)
199198
@assert (n * 2) % atom_packing_dim == 0 "N*2 must be divisible by atom_packing_dim"
200199

201200
CUDA.seed!(42)
202-
# Store as (n, batch) so reinterpret gives (2, n, batch) = (D, N2D, batch)
203-
# This matches Python's (batch, N2D, D) row-major in memory.
204201
input = CUDA.randn(ComplexF32, n, batch)
205202

206203
W0, W1, W2, T0, T1 = make_twiddles(factors)
@@ -212,7 +209,10 @@ function prepare(; benchmark::Bool=false,
212209

213210
D = atom_packing_dim
214211
N2D = n * 2 ÷ D
215-
x_packed = reinterpret(reshape, Float32, input) # (2, n, batch) = (D, N2D, batch)
212+
# Pack complex input as (D, N2D, batch) Float32 — matches Python's (batch, N2D, D) row-major.
213+
# When D=2, reinterpret gives (2, n, batch) directly. For D>2, reshape the flat layout.
214+
x_ri = reinterpret(reshape, Float32, input) # (2, n, batch)
215+
x_packed = D == 2 ? x_ri : reshape(x_ri, D, N2D, batch)
216216
y_packed = CuArray{Float32}(undef, D, N2D, batch)
217217

218218
return (;
@@ -252,8 +252,9 @@ function run(data; nruns::Int=1, warmup::Int=0)
252252
push!(times, t * 1000) # ms
253253
end
254254

255-
# Unpack output: (2, n, batch) → ComplexF32(n, batch)
256-
y_complex = reinterpret(reshape, ComplexF32, y_packed)
255+
# Unpack output: (D, N2D, batch) → (2, n, batch) → ComplexF32(n, batch)
256+
y_ri = D == 2 ? y_packed : reshape(y_packed, 2, n, batch)
257+
y_complex = reinterpret(reshape, ComplexF32, y_ri)
257258
output = copy(y_complex)
258259

259260
return (; output, times)
@@ -294,18 +295,12 @@ end
294295
function main()
295296
println("--- Running cuTile FFT Example ---")
296297

297-
BATCH_SIZE = 2
298-
FFT_SIZE = 8
299-
FFT_FACTORS = (2, 2, 2)
300-
ATOM_PACKING_DIM = 2
301-
298+
data = prepare()
302299
println(" Configuration:")
303-
println(" FFT Size (N): $FFT_SIZE")
304-
println(" Batch Size: $BATCH_SIZE")
305-
println(" FFT Factors: $FFT_FACTORS")
306-
println(" Atom Packing Dim: $ATOM_PACKING_DIM")
307-
308-
data = prepare(; batch=BATCH_SIZE, n=FFT_SIZE, factors=FFT_FACTORS, atom_packing_dim=ATOM_PACKING_DIM)
300+
println(" FFT Size (N): $(data.n)")
301+
println(" Batch Size: $(data.batch)")
302+
println(" FFT Factors: $(data.factors)")
303+
println(" Atom Packing Dim: $(data.D)")
309304
println("\nInput data shape: $(size(data.input)), dtype: $(eltype(data.input))")
310305

311306
result = run(data)

examples/fft.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,18 +111,15 @@ def fft_make_twiddles(factors, precision, device):
111111
# Example harness
112112
#=============================================================================
113113

114-
def prepare(*, benchmark: bool = False, batch: int = None, size: int = None, factors: tuple = None, atom_packing_dim: int = 2):
114+
def prepare(*, benchmark: bool = False, batch: int = None, factors: tuple = None, atom_packing_dim: int = None):
115115
"""Allocate and initialize data for FFT."""
116116
if batch is None:
117117
batch = 64 if benchmark else 2
118118
if factors is None:
119119
factors = (8, 8, 8) if benchmark else (2, 2, 2)
120120
F0, F1, F2 = factors
121121
N = F0 * F1 * F2
122-
if size is None:
123-
size = N
124-
assert size == N, f"size ({size}) must equal product of factors ({N})"
125-
D = atom_packing_dim
122+
D = min(64, N * 2) if atom_packing_dim is None else atom_packing_dim
126123

127124
input_data = torch.randn(batch, N, dtype=torch.complex64, device='cuda')
128125

@@ -218,11 +215,12 @@ def run_others(data, *, nruns: int = 1, warmup: int = 0):
218215
# Main
219216
#=============================================================================
220217

221-
def test_fft(batch, size, factors, name=None):
218+
def test_fft(batch, factors, name=None):
222219
"""Test FFT with given parameters."""
220+
size = factors[0] * factors[1] * factors[2]
223221
name = name or f"fft batch={batch}, size={size}, factors={factors}"
224222
print(f"--- {name} ---")
225-
data = prepare(batch=batch, size=size, factors=factors)
223+
data = prepare(batch=batch, factors=factors)
226224
result = run(data)
227225
verify(data, result)
228226
print(" passed")
@@ -231,8 +229,8 @@ def test_fft(batch, size, factors, name=None):
231229
def main():
232230
print("--- cuTile FFT Examples ---\n")
233231

234-
test_fft(64, 512, (8, 8, 8))
235-
test_fft(32, 512, (8, 8, 8))
232+
test_fft(64, (8, 8, 8))
233+
test_fft(32, (8, 8, 8))
236234

237235
print("\n--- All FFT examples completed ---")
238236

0 commit comments

Comments
 (0)