@@ -21,40 +21,29 @@ using FFTW
2121# Python left-multiply W @ X ↔ Julia right-multiply X * W (batch dims trailing).
2222# Python ct.permute(x, (0,2,3,1)) ↔ Julia permutedims(x, (3,1,2,4)).
2323function fft_kernel (
24- x_packed_in:: ct.TileArray{Float32, 3} , # Input (D, N2D , BS)
25- y_packed_out:: ct.TileArray{Float32, 3} , # Output (D, N2D , BS)
24+ x_packed_in:: ct.TileArray{Float32, 3} , # Input (D, 2N ÷ D , BS)
25+ y_packed_out:: ct.TileArray{Float32, 3} , # Output (D, 2N ÷ D , BS)
2626 W0:: ct.TileArray{Float32, 3} , # W0 (2, F0, F0) DFT matrix
2727 W1:: ct.TileArray{Float32, 3} , # W1 (2, F1, F1)
2828 W2:: ct.TileArray{Float32, 3} , # W2 (2, F2, F2)
2929 T0:: ct.TileArray{Float32, 3} , # T0 (2, F1F2, F0) twiddle factors
3030 T1:: ct.TileArray{Float32, 3} , # T1 (2, F2, F1) twiddle factors
31- n_const:: Int ,
32- f0_const:: Int ,
33- f1_const:: Int ,
34- f2_const:: Int ,
35- f0f1_const:: Int ,
36- f1f2_const:: Int ,
37- f0f2_const:: Int ,
38- bs_const:: Int ,
39- d_const:: Int ,
40- n2d_const:: Int
31+ N:: Int ,
32+ F0:: Int ,
33+ F1:: Int ,
34+ F2:: Int ,
35+ BS:: Int ,
36+ D:: Int ,
4137)
42- N = n_const
43- F0 = f0_const
44- F1 = f1_const
45- F2 = f2_const
46- F0F1 = f0f1_const
47- F1F2 = f1f2_const
48- F0F2 = f0f2_const
49- BS = bs_const
50- D = d_const
51- N2D = n2d_const
38+ F0F1 = F0 * F1
39+ F1F2 = F1 * F2
40+ F0F2 = F0 * F2
5241
5342 bid = ct. bid (1 )
5443
5544 # --- Load Input Data ---
56- # Input is (D, N2D , BS). Load and reshape to (2, N, BS).
57- X_ri = reshape (ct. load (x_packed_in; index= (Int32 (1 ), Int32 (1 ), bid), shape= (D, N2D , BS)), (2 , N, BS))
45+ # Input is (D, 2N ÷ D , BS). Load and reshape to (2, N, BS).
46+ X_ri = reshape (ct. load (x_packed_in; index= (Int32 (1 ), Int32 (1 ), bid), shape= (D, 2 N ÷ D , BS)), (2 , N, BS))
5847
5948 # Split real and imaginary parts, reshape to 4D factored form
6049 X_r = reshape (ct. extract (X_ri, (1 , 1 , 1 ), (1 , N, BS)), (F2, F1, F0, BS))
@@ -131,7 +120,7 @@ function fft_kernel(
131120 # --- Concatenate and Store ---
132121 X_r_final = reshape (X_r10, (1 , N, BS))
133122 X_i_final = reshape (X_i10, (1 , N, BS))
134- Y_ri = reshape (ct. cat ((X_r_final, X_i_final), 1 ), (D, N2D , BS))
123+ Y_ri = reshape (ct. cat ((X_r_final, X_i_final), 1 ), (D, 2 N ÷ D , BS))
135124 ct. store (y_packed_out; index= (Int32 (1 ), Int32 (1 ), bid), tile= Y_ri)
136125
137126 return
@@ -193,14 +182,14 @@ end
193182=============================================================================#
194183
195184function prepare (; benchmark:: Bool = false ,
196- batch:: Int = benchmark ? 64 : 2 ,
197- factors:: NTuple{3,Int} = benchmark ? (8 , 8 , 8 ) : (2 , 2 , 2 ),
185+ batch:: Int = benchmark ? 256 : 2 ,
186+ factors:: NTuple{3,Int} = benchmark ? (16 , 16 , 16 ) : (2 , 2 , 2 ),
198187 atom_packing_dim:: Int = min (64 , 2 * prod (factors)))
199- n = prod (factors)
200- @assert (n * 2 ) % atom_packing_dim == 0 " N*2 must be divisible by atom_packing_dim"
188+ N = prod (factors)
189+ @assert 2 N % atom_packing_dim == 0 " 2 * N must be divisible by atom_packing_dim"
201190
202191 cuRAND. seed! (42 )
203- input = cuRAND. randn (ComplexF32, n , batch)
192+ input = cuRAND. randn (ComplexF32, N , batch)
204193
205194 W0, W1, W2, T0, T1 = make_twiddles (factors)
206195 W0_gpu = CuArray (W0)
@@ -210,46 +199,43 @@ function prepare(; benchmark::Bool=false,
210199 T1_gpu = CuArray (T1)
211200
212201 D = atom_packing_dim
213- N2D = n * 2 ÷ D
214- # Pack complex input as (D, N2D, batch) Float32 — matches Python's (batch, N2D, D) row-major.
215- # When D=2, reinterpret gives (2, n, batch) directly. For D>2, reshape the flat layout.
216- x_ri = reinterpret (reshape, Float32, input) # (2, n, batch)
217- x_packed = D == 2 ? x_ri : reshape (x_ri, D, N2D, batch)
218- y_packed = CuArray {Float32} (undef, D, N2D, batch)
202+ # Pack complex input as (D, 2N ÷ D, batch) Float32 — matches Python's (batch, 2N ÷ D, D) row-major.
203+ # When D=2, reinterpret gives (2, N, batch) directly. For D>2, reshape the flat layout.
204+ x_ri = reinterpret (reshape, Float32, input) # (2, N, batch)
205+ x_packed = D == 2 ? x_ri : reshape (x_ri, D, 2 N ÷ D, batch)
206+ y_packed = CuArray {Float32} (undef, D, 2 N ÷ D, batch)
219207
220208 return (;
221209 input, x_packed, y_packed,
222210 W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu,
223- factors, batch, n , D, N2D
211+ factors, batch, N , D
224212 )
225213end
226214
227215function run (data; nruns:: Int = 1 , warmup:: Int = 0 )
228216 (; x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu,
229- factors, batch, n , D, N2D ) = data
217+ factors, batch, N , D) = data
230218
231219 F0, F1, F2 = factors
232- F0F1 = F0 * F1
233- F1F2 = F1 * F2
234- F0F2 = F0 * F2
235- grid = (batch, 1 , 1 )
220+ BS = 1
221+ grid = (batch ÷ BS, 1 , 1 )
236222
237223 CUDACore. @sync for _ in 1 : warmup
238- @cuda backend= cuTile blocks= grid fft_kernel (x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct. Constant (n ), ct. Constant (F0), ct. Constant (F1), ct. Constant (F2), ct. Constant (F0F1 ), ct. Constant (F1F2), ct . Constant (F0F2), ct . Constant (batch), ct . Constant (D), ct . Constant (N2D ))
224+ @cuda backend= cuTile blocks= grid fft_kernel (x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct. Constant (N ), ct. Constant (F0), ct. Constant (F1), ct. Constant (F2), ct. Constant (BS ), ct. Constant (D ))
239225 end
240226
241227 times = Float64[]
242228 NVTX. @range " cuTile" begin
243229 for i in 1 : nruns
244230 NVTX. @range " run $i " begin
245- t = CUDACore. @elapsed @cuda backend= cuTile blocks= grid fft_kernel (x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct. Constant (n ), ct. Constant (F0), ct. Constant (F1), ct. Constant (F2), ct. Constant (F0F1 ), ct. Constant (F1F2), ct . Constant (F0F2), ct . Constant (batch), ct . Constant (D), ct . Constant (N2D ))
231+ t = CUDACore. @elapsed @cuda backend= cuTile blocks= grid fft_kernel (x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct. Constant (N ), ct. Constant (F0), ct. Constant (F1), ct. Constant (F2), ct. Constant (BS ), ct. Constant (D ))
246232 push! (times, t * 1000 ) # ms
247233 end
248234 end
249235 end
250236
251- # Unpack output: (D, N2D , batch) → (2, n , batch) → ComplexF32(n, batch)
252- y_ri = D == 2 ? y_packed : reshape (y_packed, 2 , n , batch)
237+ # Unpack output: (D, 2N ÷ D, batch) → (2, N, batch) → ComplexF32(N, batch)
238+ y_ri = D == 2 ? y_packed : reshape (y_packed, 2 , N , batch)
253239 y_complex = reinterpret (reshape, ComplexF32, y_ri)
254240 output = copy (y_complex)
255241
@@ -272,18 +258,19 @@ end
272258=============================================================================#
273259
274260function run_others (data; nruns:: Int = 1 , warmup:: Int = 0 )
275- (; input, batch, n ) = data
261+ (; input, batch, N ) = data
276262 results = Dict {String, Vector{Float64}} ()
277263
264+ plan = cuFFT. plan_fft! (input, 1 )
278265 CUDACore. @sync for _ in 1 : warmup
279- cuFFT . fft! ( copy (input), 1 )
266+ plan * copy (input)
280267 end
281268 times_cufft = Float64[]
282269 NVTX. @range " cuFFT" begin
283270 for i in 1 : nruns
284271 NVTX. @range " run $i " begin
285272 input_copy = copy (input)
286- t = CUDACore. @elapsed cuFFT . fft! (input_copy, 1 )
273+ t = CUDACore. @elapsed plan * input_copy
287274 push! (times_cufft, t * 1000 )
288275 end
289276 end
0 commit comments