Skip to content

Commit a0049f9

Browse files
committed
Update FFT example.
1 parent 72c5768 commit a0049f9

1 file changed

Lines changed: 44 additions & 45 deletions

File tree

examples/fft.jl

Lines changed: 44 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ using FFTW
2020
# in columns. In Julia column-major, reshape (F1F2, F0) puts stride-F0 elements in rows.
2121
# We use right-multiply X @ W instead of W @ X to process rows instead of columns.
2222
#
23-
# Input/output layout: (D, BS, N2D) where D=2 for real/imag interleaving.
23+
# Input/output memory layout: (D, BS, N2D) where D=2 for real/imag interleaving.
24+
# Internally, BS is permuted to trailing position for batched matmul convention.
2425
function fft_kernel(
2526
x_packed_in::ct.TileArray{Float32, 3}, # Input (D, BS, N2D) - natural Julia complex layout
2627
y_packed_out::ct.TileArray{Float32, 3}, # Output (D, BS, N2D)
@@ -55,96 +56,94 @@ function fft_kernel(
5556
bid = ct.bid(1)
5657

5758
# --- Load Input Data ---
58-
# Input is (D, BS, N2D) where D=2 for real/imag. Load and reshape to (2, BS, N).
59-
X_ri = reshape(ct.load(x_packed_in; index=(1, bid, 1), shape=(D, BS, N2D)), (2, BS, N))
59+
# Input is (D, BS, N2D) where D=2 for real/imag. Load and permute BS to trailing.
60+
X_ri_mem = reshape(ct.load(x_packed_in; index=(1, bid, 1), shape=(D, BS, N2D)), (2, BS, N))
61+
X_ri = permutedims(X_ri_mem, (1, 3, 2)) # (2, N, BS) — trailing batch
6062

6163
# Split real and imaginary parts (extract from first dimension)
62-
X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, BS, N)), (BS, F1F2, F0))
63-
X_i = reshape(ct.extract(X_ri, (2, 1, 1), (1, BS, N)), (BS, F1F2, F0))
64+
X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, N, BS)), (F1F2, F0, BS))
65+
X_i = reshape(ct.extract(X_ri, (2, 1, 1), (1, N, BS)), (F1F2, F0, BS))
6466

6567
# --- Load DFT Matrices ---
66-
# W0 (F0 x F0) - for right-multiply X @ W0
68+
# W0 (F0 x F0) - for right-multiply X @ W0, batch dim trailing
6769
W0_ri = reshape(ct.load(W0; index=(1, 1, 1), shape=(F0, F0, 2)), (F0, F0, 2))
68-
W0_r = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 1), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0))
69-
W0_i = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 2), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0))
70+
W0_r = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 1), (F0, F0, 1)), (F0, F0, 1)), (F0, F0, BS))
71+
W0_i = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 2), (F0, F0, 1)), (F0, F0, 1)), (F0, F0, BS))
7072

7173
# W1 (F1 x F1)
7274
W1_ri = reshape(ct.load(W1; index=(1, 1, 1), shape=(F1, F1, 2)), (F1, F1, 2))
73-
W1_r = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 1), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1))
74-
W1_i = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 2), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1))
75+
W1_r = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 1), (F1, F1, 1)), (F1, F1, 1)), (F1, F1, BS))
76+
W1_i = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 2), (F1, F1, 1)), (F1, F1, 1)), (F1, F1, BS))
7577

7678
# W2 (F2 x F2)
7779
W2_ri = reshape(ct.load(W2; index=(1, 1, 1), shape=(F2, F2, 2)), (F2, F2, 2))
78-
W2_r = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 1), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2))
79-
W2_i = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 2), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2))
80+
W2_r = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 1), (F2, F2, 1)), (F2, F2, 1)), (F2, F2, BS))
81+
W2_i = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 2), (F2, F2, 1)), (F2, F2, 1)), (F2, F2, BS))
8082

8183
# --- Load Twiddle Factors ---
8284
# T0 (F1F2, F0) - note swapped from Python's (F0, F1F2)
8385
T0_ri = reshape(ct.load(T0; index=(1, 1, 1), shape=(F1F2, F0, 2)), (F1F2, F0, 2))
84-
T0_r = reshape(ct.extract(T0_ri, (1, 1, 1), (F1F2, F0, 1)), (1, N))
85-
T0_i = reshape(ct.extract(T0_ri, (1, 1, 2), (F1F2, F0, 1)), (1, N))
86+
T0_r = reshape(ct.extract(T0_ri, (1, 1, 1), (F1F2, F0, 1)), (N, 1))
87+
T0_i = reshape(ct.extract(T0_ri, (1, 1, 2), (F1F2, F0, 1)), (N, 1))
8688

8789
# T1 (F0F2, F1) - note swapped from Python's (F1, F2)
8890
T1_ri = reshape(ct.load(T1; index=(1, 1, 1), shape=(F0F2, F1, 2)), (F0F2, F1, 2))
89-
T1_r = reshape(ct.extract(T1_ri, (1, 1, 1), (F0F2, F1, 1)), (1, F0F2 * F1))
90-
T1_i = reshape(ct.extract(T1_ri, (1, 1, 2), (F0F2, F1, 1)), (1, F0F2 * F1))
91+
T1_r = reshape(ct.extract(T1_ri, (1, 1, 1), (F0F2, F1, 1)), (F0F2 * F1, 1))
92+
T1_i = reshape(ct.extract(T1_ri, (1, 1, 2), (F0F2, F1, 1)), (F0F2 * F1, 1))
9193

9294
# --- Stage 0: F0-point DFT ---
93-
# X is (BS, F1F2, F0), W0 is (BS, F0, F0)
95+
# X is (F1F2, F0, BS), W0 is (F0, F0, BS) — trailing batch
9496
# Right-multiply: X @ W0 processes each row (F1F2 rows, each with F0 elements)
95-
# Each row has elements at stride F1F2 in the original array - exactly what we need!
96-
X_r_ = X_r * W0_r - X_i * W0_i # (BS, F1F2, F0) @ (BS, F0, F0) → (BS, F1F2, F0)
97+
X_r_ = X_r * W0_r - X_i * W0_i # (F1F2, F0, BS) @ (F0, F0, BS) → (F1F2, F0, BS)
9798
X_i_ = X_r * W0_i + X_i * W0_r
9899

99100
# --- Twiddle & Permute 0 ---
100-
# Reshape to (BS, N) for element-wise twiddle multiply
101-
X_r_flat = reshape(X_r_, (BS, N))
102-
X_i_flat = reshape(X_i_, (BS, N))
101+
# Reshape to (N, BS) for element-wise twiddle multiply
102+
X_r_flat = reshape(X_r_, (N, BS))
103+
X_i_flat = reshape(X_i_, (N, BS))
103104
X_r2 = T0_r .* X_r_flat .- T0_i .* X_i_flat
104105
X_i2 = T0_i .* X_r_flat .+ T0_r .* X_i_flat
105106

106107
# Reshape and permute for stage 1
107-
# Current logical layout after reshape (BS, F1F2, F0): data at (bs, f1*F2+f2, f0)
108-
# Reshape to (BS, F2, F1, F0) then permute to (BS, F0F2, F1) for stage 1
109-
X_r3 = reshape(X_r2, (BS, F2, F1, F0))
110-
X_i3 = reshape(X_i2, (BS, F2, F1, F0))
111-
X_r4 = permutedims(X_r3, (1, 2, 4, 3)) # (BS, F2, F0, F1)
112-
X_i4 = permutedims(X_i3, (1, 2, 4, 3))
113-
X_r5 = reshape(X_r4, (BS, F0F2, F1))
114-
X_i5 = reshape(X_i4, (BS, F0F2, F1))
108+
# Reshape to (F2, F1, F0, BS) then permute to (F0F2, F1, BS) for stage 1
109+
X_r3 = reshape(X_r2, (F2, F1, F0, BS))
110+
X_i3 = reshape(X_i2, (F2, F1, F0, BS))
111+
X_r4 = permutedims(X_r3, (1, 3, 2, 4)) # (F2, F0, F1, BS)
112+
X_i4 = permutedims(X_i3, (1, 3, 2, 4))
113+
X_r5 = reshape(X_r4, (F0F2, F1, BS))
114+
X_i5 = reshape(X_i4, (F0F2, F1, BS))
115115

116116
# --- Stage 1: F1-point DFT ---
117-
# X is (BS, F0F2, F1), W1 is (BS, F1, F1)
117+
# X is (F0F2, F1, BS), W1 is (F1, F1, BS)
118118
X_r6 = X_r5 * W1_r - X_i5 * W1_i
119119
X_i6 = X_r5 * W1_i + X_i5 * W1_r
120120

121121
# --- Twiddle & Permute 1 ---
122-
X_r_flat2 = reshape(X_r6, (BS, N))
123-
X_i_flat2 = reshape(X_i6, (BS, N))
122+
X_r_flat2 = reshape(X_r6, (N, BS))
123+
X_i_flat2 = reshape(X_i6, (N, BS))
124124
X_r7 = T1_r .* X_r_flat2 .- T1_i .* X_i_flat2
125125
X_i7 = T1_i .* X_r_flat2 .+ T1_r .* X_i_flat2
126126

127127
# Reshape and permute for stage 2
128-
X_r8 = reshape(X_r7, (BS, F2, F0, F1))
129-
X_i8 = reshape(X_i7, (BS, F2, F0, F1))
130-
X_r9 = permutedims(X_r8, (1, 3, 4, 2)) # (BS, F0, F1, F2)
131-
X_i9 = permutedims(X_i8, (1, 3, 4, 2))
132-
X_r10 = reshape(X_r9, (BS, F0F1, F2))
133-
X_i10 = reshape(X_i9, (BS, F0F1, F2))
128+
X_r8 = reshape(X_r7, (F2, F0, F1, BS))
129+
X_i8 = reshape(X_i7, (F2, F0, F1, BS))
130+
X_r9 = permutedims(X_r8, (2, 3, 1, 4)) # (F0, F1, F2, BS)
131+
X_i9 = permutedims(X_i8, (2, 3, 1, 4))
132+
X_r10 = reshape(X_r9, (F0F1, F2, BS))
133+
X_i10 = reshape(X_i9, (F0F1, F2, BS))
134134

135135
# --- Stage 2: F2-point DFT ---
136-
# X is (BS, F0F1, F2), W2 is (BS, F2, F2)
136+
# X is (F0F1, F2, BS), W2 is (F2, F2, BS)
137137
X_r11 = X_r10 * W2_r - X_i10 * W2_i
138138
X_i11 = X_r10 * W2_i + X_i10 * W2_r
139139

140140
# --- Final Output ---
141-
# After stage 2, data is in (BS, F0F1, F2) layout
142-
# Reshape to (BS, F0, F1, F2) - output is already in frequency order
143-
X_r_final = reshape(X_r11, (1, BS, N))
144-
X_i_final = reshape(X_i11, (1, BS, N))
141+
X_r_final = reshape(X_r11, (1, N, BS))
142+
X_i_final = reshape(X_i11, (1, N, BS))
145143

146144
# --- Concatenate and Store ---
147-
Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, BS, N2D))
145+
# Permute BS back to middle for memory layout (D, BS, N2D)
146+
Y_ri = permutedims(reshape(ct.cat((X_r_final, X_i_final), 1), (D, N2D, BS)), (1, 3, 2))
148147
ct.store(y_packed_out; index=(1, bid, 1), tile=Y_ri)
149148

150149
return

0 commit comments

Comments
 (0)