Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 44 additions & 45 deletions examples/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ using FFTW
# in columns. In Julia column-major, reshape (F1F2, F0) puts stride-F0 elements in rows.
# We use right-multiply X @ W instead of W @ X to process rows instead of columns.
#
# Input/output layout: (D, BS, N2D) where D=2 for real/imag interleaving.
# Input/output memory layout: (D, BS, N2D) where D=2 for real/imag interleaving.
# Internally, BS is permuted to trailing position for batched matmul convention.
function fft_kernel(
x_packed_in::ct.TileArray{Float32, 3}, # Input (D, BS, N2D) - natural Julia complex layout
y_packed_out::ct.TileArray{Float32, 3}, # Output (D, BS, N2D)
Expand Down Expand Up @@ -55,96 +56,94 @@ function fft_kernel(
bid = ct.bid(1)

# --- Load Input Data ---
# Input is (D, BS, N2D) where D=2 for real/imag. Load and reshape to (2, BS, N).
X_ri = reshape(ct.load(x_packed_in; index=(1, bid, 1), shape=(D, BS, N2D)), (2, BS, N))
# Input is (D, BS, N2D) where D=2 for real/imag. Load and permute BS to trailing.
X_ri_mem = reshape(ct.load(x_packed_in; index=(1, bid, 1), shape=(D, BS, N2D)), (2, BS, N))
X_ri = permutedims(X_ri_mem, (1, 3, 2)) # (2, N, BS) — trailing batch

# Split real and imaginary parts (extract from first dimension)
X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, BS, N)), (BS, F1F2, F0))
X_i = reshape(ct.extract(X_ri, (2, 1, 1), (1, BS, N)), (BS, F1F2, F0))
X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, N, BS)), (F1F2, F0, BS))
X_i = reshape(ct.extract(X_ri, (2, 1, 1), (1, N, BS)), (F1F2, F0, BS))

# --- Load DFT Matrices ---
# W0 (F0 x F0) - for right-multiply X @ W0
# W0 (F0 x F0) - for right-multiply X @ W0, batch dim trailing
W0_ri = reshape(ct.load(W0; index=(1, 1, 1), shape=(F0, F0, 2)), (F0, F0, 2))
W0_r = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 1), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0))
W0_i = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 2), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0))
W0_r = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 1), (F0, F0, 1)), (F0, F0, 1)), (F0, F0, BS))
W0_i = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 2), (F0, F0, 1)), (F0, F0, 1)), (F0, F0, BS))

# W1 (F1 x F1)
W1_ri = reshape(ct.load(W1; index=(1, 1, 1), shape=(F1, F1, 2)), (F1, F1, 2))
W1_r = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 1), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1))
W1_i = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 2), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1))
W1_r = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 1), (F1, F1, 1)), (F1, F1, 1)), (F1, F1, BS))
W1_i = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 2), (F1, F1, 1)), (F1, F1, 1)), (F1, F1, BS))

# W2 (F2 x F2)
W2_ri = reshape(ct.load(W2; index=(1, 1, 1), shape=(F2, F2, 2)), (F2, F2, 2))
W2_r = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 1), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2))
W2_i = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 2), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2))
W2_r = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 1), (F2, F2, 1)), (F2, F2, 1)), (F2, F2, BS))
W2_i = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 2), (F2, F2, 1)), (F2, F2, 1)), (F2, F2, BS))

# --- Load Twiddle Factors ---
# T0 (F1F2, F0) - note swapped from Python's (F0, F1F2)
T0_ri = reshape(ct.load(T0; index=(1, 1, 1), shape=(F1F2, F0, 2)), (F1F2, F0, 2))
T0_r = reshape(ct.extract(T0_ri, (1, 1, 1), (F1F2, F0, 1)), (1, N))
T0_i = reshape(ct.extract(T0_ri, (1, 1, 2), (F1F2, F0, 1)), (1, N))
T0_r = reshape(ct.extract(T0_ri, (1, 1, 1), (F1F2, F0, 1)), (N, 1))
T0_i = reshape(ct.extract(T0_ri, (1, 1, 2), (F1F2, F0, 1)), (N, 1))

# T1 (F0F2, F1) - note swapped from Python's (F1, F2)
T1_ri = reshape(ct.load(T1; index=(1, 1, 1), shape=(F0F2, F1, 2)), (F0F2, F1, 2))
T1_r = reshape(ct.extract(T1_ri, (1, 1, 1), (F0F2, F1, 1)), (1, F0F2 * F1))
T1_i = reshape(ct.extract(T1_ri, (1, 1, 2), (F0F2, F1, 1)), (1, F0F2 * F1))
T1_r = reshape(ct.extract(T1_ri, (1, 1, 1), (F0F2, F1, 1)), (F0F2 * F1, 1))
T1_i = reshape(ct.extract(T1_ri, (1, 1, 2), (F0F2, F1, 1)), (F0F2 * F1, 1))

# --- Stage 0: F0-point DFT ---
# X is (BS, F1F2, F0), W0 is (BS, F0, F0)
# X is (F1F2, F0, BS), W0 is (F0, F0, BS) — trailing batch
# Right-multiply: X @ W0 processes each row (F1F2 rows, each with F0 elements)
# Each row has elements at stride F1F2 in the original array - exactly what we need!
X_r_ = X_r * W0_r - X_i * W0_i # (BS, F1F2, F0) @ (BS, F0, F0) → (BS, F1F2, F0)
X_r_ = X_r * W0_r - X_i * W0_i # (F1F2, F0, BS) @ (F0, F0, BS) → (F1F2, F0, BS)
X_i_ = X_r * W0_i + X_i * W0_r

# --- Twiddle & Permute 0 ---
# Reshape to (BS, N) for element-wise twiddle multiply
X_r_flat = reshape(X_r_, (BS, N))
X_i_flat = reshape(X_i_, (BS, N))
# Reshape to (N, BS) for element-wise twiddle multiply
X_r_flat = reshape(X_r_, (N, BS))
X_i_flat = reshape(X_i_, (N, BS))
X_r2 = T0_r .* X_r_flat .- T0_i .* X_i_flat
X_i2 = T0_i .* X_r_flat .+ T0_r .* X_i_flat

# Reshape and permute for stage 1
# Current logical layout after reshape (BS, F1F2, F0): data at (bs, f1*F2+f2, f0)
# Reshape to (BS, F2, F1, F0) then permute to (BS, F0F2, F1) for stage 1
X_r3 = reshape(X_r2, (BS, F2, F1, F0))
X_i3 = reshape(X_i2, (BS, F2, F1, F0))
X_r4 = permutedims(X_r3, (1, 2, 4, 3)) # (BS, F2, F0, F1)
X_i4 = permutedims(X_i3, (1, 2, 4, 3))
X_r5 = reshape(X_r4, (BS, F0F2, F1))
X_i5 = reshape(X_i4, (BS, F0F2, F1))
# Reshape to (F2, F1, F0, BS) then permute to (F0F2, F1, BS) for stage 1
X_r3 = reshape(X_r2, (F2, F1, F0, BS))
X_i3 = reshape(X_i2, (F2, F1, F0, BS))
X_r4 = permutedims(X_r3, (1, 3, 2, 4)) # (F2, F0, F1, BS)
X_i4 = permutedims(X_i3, (1, 3, 2, 4))
X_r5 = reshape(X_r4, (F0F2, F1, BS))
X_i5 = reshape(X_i4, (F0F2, F1, BS))

# --- Stage 1: F1-point DFT ---
# X is (BS, F0F2, F1), W1 is (BS, F1, F1)
# X is (F0F2, F1, BS), W1 is (F1, F1, BS)
X_r6 = X_r5 * W1_r - X_i5 * W1_i
X_i6 = X_r5 * W1_i + X_i5 * W1_r

# --- Twiddle & Permute 1 ---
X_r_flat2 = reshape(X_r6, (BS, N))
X_i_flat2 = reshape(X_i6, (BS, N))
X_r_flat2 = reshape(X_r6, (N, BS))
X_i_flat2 = reshape(X_i6, (N, BS))
X_r7 = T1_r .* X_r_flat2 .- T1_i .* X_i_flat2
X_i7 = T1_i .* X_r_flat2 .+ T1_r .* X_i_flat2

# Reshape and permute for stage 2
X_r8 = reshape(X_r7, (BS, F2, F0, F1))
X_i8 = reshape(X_i7, (BS, F2, F0, F1))
X_r9 = permutedims(X_r8, (1, 3, 4, 2)) # (BS, F0, F1, F2)
X_i9 = permutedims(X_i8, (1, 3, 4, 2))
X_r10 = reshape(X_r9, (BS, F0F1, F2))
X_i10 = reshape(X_i9, (BS, F0F1, F2))
X_r8 = reshape(X_r7, (F2, F0, F1, BS))
X_i8 = reshape(X_i7, (F2, F0, F1, BS))
X_r9 = permutedims(X_r8, (2, 3, 1, 4)) # (F0, F1, F2, BS)
X_i9 = permutedims(X_i8, (2, 3, 1, 4))
X_r10 = reshape(X_r9, (F0F1, F2, BS))
X_i10 = reshape(X_i9, (F0F1, F2, BS))

# --- Stage 2: F2-point DFT ---
# X is (BS, F0F1, F2), W2 is (BS, F2, F2)
# X is (F0F1, F2, BS), W2 is (F2, F2, BS)
X_r11 = X_r10 * W2_r - X_i10 * W2_i
X_i11 = X_r10 * W2_i + X_i10 * W2_r

# --- Final Output ---
# After stage 2, data is in (BS, F0F1, F2) layout
# Reshape to (BS, F0, F1, F2) - output is already in frequency order
X_r_final = reshape(X_r11, (1, BS, N))
X_i_final = reshape(X_i11, (1, BS, N))
X_r_final = reshape(X_r11, (1, N, BS))
X_i_final = reshape(X_i11, (1, N, BS))

# --- Concatenate and Store ---
Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, BS, N2D))
# Permute BS back to middle for memory layout (D, BS, N2D)
Y_ri = permutedims(reshape(ct.cat((X_r_final, X_i_final), 1), (D, N2D, BS)), (1, 3, 2))
ct.store(y_packed_out; index=(1, bid, 1), tile=Y_ri)

return
Expand Down
141 changes: 127 additions & 14 deletions src/language/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -922,32 +922,145 @@ end
=============================================================================#

# Matrix multiply-accumulate: muladd(a, b, acc) = a * b + acc
@inline Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}) where {T1, T2, T3, SA, SB, SC} =
# Handles 1D promotion, type promotion, and batched dims (≥3D).
# Note: SA, SB, SC type parameters required to avoid ambiguity with scalar methods during codegen
@inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}) where {T1, T2, T3, SA, SB, SC}
_muladd(a, b, acc, Val(ndims(a)), Val(ndims(b)))
end

# 2D × 2D: direct MmaFOp with type promotion
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2})
Intrinsics.mma(a, b, acc)
end

# Vec-mat (1D × 2D): reshape (M,) → (M, 1), MmaFOp, acc is already (M, N)
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{1}, ::Val{2})
a2d = reshape(a, (size(a, 1), 1))
_muladd(a2d, b, acc, Val(2), Val(2))
end

# Mat-vec (2D × 1D): reshape b (K,) → (K, 1), acc (M,) → (M, 1), MmaFOp, squeeze back
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{1})
M, K = size(a, 1), size(b, 1)
b2d = reshape(b, (K, 1))
acc2d = reshape(acc, (M, 1))
result = _muladd(a, b2d, acc2d, Val(2), Val(2))
reshape(result, (M,))
end

# Vec-vec (1D × 1D): not supported
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1})
return :(throw(ArgumentError("Vector-vector multiply-accumulate is not supported.")))
end

# Batched mat-vec / vec-mat (≥3D × 1D or 1D × ≥3D): not supported, unsqueeze manually
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB}
NB >= 3 || return :(throw(ArgumentError("unreachable")))
return :(throw(ArgumentError("Batched vec-mat is not supported. Reshape the 1D operand to 2D first.")))
end
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA}
NA >= 3 || return :(throw(ArgumentError("unreachable")))
return :(throw(ArgumentError("Batched mat-vec is not supported. Reshape the 1D operand to 2D first.")))
end

# Batched matmul (≥3D × ≥3D): trailing batch dims with broadcast
# Julia convention: first two dims are matrix (M,K)/(K,N), trailing dims are batch.
# MmaFOp expects exactly 3D tiles (B, M, K), so we:
# 1. Broadcast batch dims to a common shape
# 2. Permute trailing batch → leading
# 3. Flatten multiple batch dims into one for MmaFOp
# 4. Unflatten + permute back after
@generated function _muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC},
::Val{NA}, ::Val{NB}) where {T1, T2, T3, SA, SB, SC, NA, NB}
sa = Tuple(SA.parameters)
sb = Tuple(SB.parameters)

# Matrix dims are first two; batch dims are trailing
M = sa[1]; K = sa[2]; N = sb[2]
a_batch = sa[3:end]
b_batch = sb[3:end]

# Broadcast batch dims (pad shorter with trailing 1s, then broadcast)
n_batch = max(length(a_batch), length(b_batch))
a_batch_padded = (a_batch..., ntuple(Returns(1), n_batch - length(a_batch))...)
b_batch_padded = (b_batch..., ntuple(Returns(1), n_batch - length(b_batch))...)
batch_shape = map(max, a_batch_padded, b_batch_padded)
B_flat = prod(batch_shape)

quote
# Reshape + broadcast to align batch dims (still trailing)
a_bc = broadcast_to(reshape(a, $((M, K, a_batch_padded...))), $((M, K, batch_shape...)))
b_bc = broadcast_to(reshape(b, $((K, N, b_batch_padded...))), $((K, N, batch_shape...)))
acc_bc = broadcast_to(acc, $((M, N, batch_shape...)))
# Flatten batch dims to one (still trailing), then permute to leading
a_3d = permutedims(reshape(a_bc, $((M, K, B_flat))), (3, 1, 2))
b_3d = permutedims(reshape(b_bc, $((K, N, B_flat))), (3, 1, 2))
acc_3d = permutedims(reshape(acc_bc, $((M, N, B_flat))), (3, 1, 2))
# MmaFOp
result_3d = Intrinsics.mma(a_3d, b_3d, acc_3d)
# Permute back to trailing, unflatten batch dims
reshape(permutedims(result_3d, (2, 3, 1)), $((M, N, batch_shape...)))
end
end

# Matrix multiplication (A * B like Julia arrays)
# Matrix multiplication: A * B = muladd(A, B, zeros)
# Note: SA, SB type parameters required to avoid ambiguity with scalar*tile methods during codegen
@inline function Base.:(*)(a::Tile{T1, SA}, b::Tile{T2, SB}) where {T1, T2, SA, SB}
_matmul(a, b, Val(ndims(a)))
_matmul(a, b, Val(ndims(a)), Val(ndims(b)))
end

# 2D matmul: (M, K) × (K, N) → (M, N)
@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{2}) where {T1}
M = size(a, 1)
N = size(b, 2)
acc = zeros(T1, (M, N))
# 2D × 2D → (M, N)
@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{2}, ::Val{2}) where {T1}
acc = zeros(T1, (size(a, 1), size(b, 2)))
muladd(a, b, acc)
end

# 3D batched matmul: (B, M, K) × (B, K, N) → (B, M, N)
@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{3}) where {T1}
B = max(size(a, 1), size(b, 1)) # Broadcast batch dimension
M = size(a, 2)
N = size(b, 3)
acc = zeros(T1, (B, M, N))
# Vec-mat (1D × 2D) → (M, N)
@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{1}, ::Val{2}) where {T1}
acc = zeros(T1, (size(a, 1), size(b, 2)))
muladd(a, b, acc)
end

# Mat-vec (2D × 1D) → (M,)
@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{2}, ::Val{1}) where {T1}
acc = zeros(T1, (size(a, 1),))
muladd(a, b, acc)
end

# Vec-vec (1D × 1D): not supported
@generated function _matmul(::Tile, ::Tile, ::Val{1}, ::Val{1})
return :(throw(ArgumentError("Vector-vector multiplication is not supported. Use dot(a, b) for inner products, or reshape explicitly.")))
end

# Batched (≥3D × ≥3D) → (M, N, batch...)
@generated function _matmul(a::Tile{T1, SA}, b::Tile{T2, SB},
::Val{NA}, ::Val{NB}) where {T1, T2, SA, SB, NA, NB}
sa = Tuple(SA.parameters)
sb = Tuple(SB.parameters)
a_batch = sa[3:end]
b_batch = sb[3:end]
n_batch = max(length(a_batch), length(b_batch))
a_batch_padded = (a_batch..., ntuple(_ -> 1, n_batch - length(a_batch))...)
b_batch_padded = (b_batch..., ntuple(_ -> 1, n_batch - length(b_batch))...)
batch_shape = map(max, a_batch_padded, b_batch_padded)
M = sa[1]; N = sb[2]
out_shape = (M, N, batch_shape...)
quote
acc = zeros(T1, $out_shape)
muladd(a, b, acc)
end
end

# Batched × 1D: not supported — unsqueeze the 1D operand manually
@generated function _matmul(::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA}
NA >= 3 || return :(throw(ArgumentError("unreachable")))
return :(throw(ArgumentError("Batched mat-vec is not supported. Reshape the 1D operand to 2D first.")))
end
@generated function _matmul(::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB}
NB >= 3 || return :(throw(ArgumentError("unreachable")))
return :(throw(ArgumentError("Batched vec-mat is not supported. Reshape the 1D operand to 2D first.")))
end

#=============================================================================
Selection
=============================================================================#
Expand Down
Loading
Loading