@@ -52,7 +52,7 @@ function fft_kernel(
5252
5353 # --- Load Input Data ---
5454 # Input is (D, N2D, BS). Load and reshape to (2, N, BS).
55- X_ri = reshape (ct. load (x_packed_in; index= (1 , 1 , bid), shape= (D, N2D, BS)), (2 , N, BS))
55+ X_ri = reshape (ct. load (x_packed_in; index= (Int32 ( 1 ), Int32 ( 1 ) , bid), shape= (D, N2D, BS)), (2 , N, BS))
5656
5757 # Split real and imaginary parts, reshape to 4D factored form
5858 X_r = reshape (ct. extract (X_ri, (1 , 1 , 1 ), (1 , N, BS)), (F2, F1, F0, BS))
@@ -130,7 +130,7 @@ function fft_kernel(
130130 X_r_final = reshape (X_r10, (1 , N, BS))
131131 X_i_final = reshape (X_i10, (1 , N, BS))
132132 Y_ri = reshape (ct. cat ((X_r_final, X_i_final), 1 ), (D, N2D, BS))
133- ct. store (y_packed_out; index= (1 , 1 , bid), tile= Y_ri)
133+ ct. store (y_packed_out; index= (Int32 ( 1 ), Int32 ( 1 ) , bid), tile= Y_ri)
134134
135135 return
136136end
@@ -192,15 +192,12 @@ end
192192
193193function prepare (; benchmark:: Bool = false ,
194194 batch:: Int = benchmark ? 64 : 2 ,
195- n:: Int = benchmark ? 512 : 8 ,
196195 factors:: NTuple{3,Int} = benchmark ? (8 , 8 , 8 ) : (2 , 2 , 2 ),
197- atom_packing_dim:: Int = 2 )
198- @assert factors[ 1 ] * factors[ 2 ] * factors[ 3 ] == n " Factors must multiply to N "
196+ atom_packing_dim:: Int = min ( 64 , 2 * prod (factors)) )
197+ n = prod (factors)
199198 @assert (n * 2 ) % atom_packing_dim == 0 " N*2 must be divisible by atom_packing_dim"
200199
201200 CUDA. seed! (42 )
202- # Store as (n, batch) so reinterpret gives (2, n, batch) = (D, N2D, batch)
203- # This matches Python's (batch, N2D, D) row-major in memory.
204201 input = CUDA. randn (ComplexF32, n, batch)
205202
206203 W0, W1, W2, T0, T1 = make_twiddles (factors)
@@ -212,7 +209,10 @@ function prepare(; benchmark::Bool=false,
212209
213210 D = atom_packing_dim
214211 N2D = n * 2 ÷ D
215- x_packed = reinterpret (reshape, Float32, input) # (2, n, batch) = (D, N2D, batch)
212+ # Pack complex input as (D, N2D, batch) Float32 — matches Python's (batch, N2D, D) row-major.
213+ # When D=2, reinterpret gives (2, n, batch) directly. For D>2, reshape the flat layout.
214+ x_ri = reinterpret (reshape, Float32, input) # (2, n, batch)
215+ x_packed = D == 2 ? x_ri : reshape (x_ri, D, N2D, batch)
216216 y_packed = CuArray {Float32} (undef, D, N2D, batch)
217217
218218 return (;
@@ -252,8 +252,9 @@ function run(data; nruns::Int=1, warmup::Int=0)
252252 push! (times, t * 1000 ) # ms
253253 end
254254
255- # Unpack output: (2, n, batch) → ComplexF32(n, batch)
256- y_complex = reinterpret (reshape, ComplexF32, y_packed)
255+ # Unpack output: (D, N2D, batch) → (2, n, batch) → ComplexF32(n, batch)
256+ y_ri = D == 2 ? y_packed : reshape (y_packed, 2 , n, batch)
257+ y_complex = reinterpret (reshape, ComplexF32, y_ri)
257258 output = copy (y_complex)
258259
259260 return (; output, times)
@@ -294,18 +295,12 @@ end
294295function main ()
295296 println (" --- Running cuTile FFT Example ---" )
296297
297- BATCH_SIZE = 2
298- FFT_SIZE = 8
299- FFT_FACTORS = (2 , 2 , 2 )
300- ATOM_PACKING_DIM = 2
301-
298+ data = prepare ()
302299 println (" Configuration:" )
303- println (" FFT Size (N): $FFT_SIZE " )
304- println (" Batch Size: $BATCH_SIZE " )
305- println (" FFT Factors: $FFT_FACTORS " )
306- println (" Atom Packing Dim: $ATOM_PACKING_DIM " )
307-
308- data = prepare (; batch= BATCH_SIZE, n= FFT_SIZE, factors= FFT_FACTORS, atom_packing_dim= ATOM_PACKING_DIM)
300+ println (" FFT Size (N): $(data. n) " )
301+ println (" Batch Size: $(data. batch) " )
302+ println (" FFT Factors: $(data. factors) " )
303+ println (" Atom Packing Dim: $(data. D) " )
309304 println (" \n Input data shape: $(size (data. input)) , dtype: $(eltype (data. input)) " )
310305
311306 result = run (data)
0 commit comments