From 72c576802732558e46a451f9b6da5be1a4566599 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 20 Mar 2026 17:44:25 +0000 Subject: [PATCH 1/3] switch to trailing batches; allow mat-vec, vec-mat --- src/language/operations.jl | 141 +++++++++++++++-- test/codegen/operations.jl | 76 +++++++++- test/device/atomics.jl | 6 +- test/device/tile.jl | 299 +++++++++++++++++++++++++++++++++++++ 4 files changed, 499 insertions(+), 23 deletions(-) diff --git a/src/language/operations.jl b/src/language/operations.jl index 5d788c4a..3f4b51f8 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -922,32 +922,145 @@ end =============================================================================# # Matrix multiply-accumulate: muladd(a, b, acc) = a * b + acc -@inline Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}) where {T1, T2, T3, SA, SB, SC} = +# Handles 1D promotion, type promotion, and batched dims (≥3D). +# Note: SA, SB, SC type parameters required to avoid ambiguity with scalar methods during codegen +@inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}) where {T1, T2, T3, SA, SB, SC} + _muladd(a, b, acc, Val(ndims(a)), Val(ndims(b))) +end + +# 2D × 2D: direct MmaFOp with type promotion +@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2}) Intrinsics.mma(a, b, acc) +end + +# Vec-mat (1D × 2D): reshape (M,) → (M, 1), MmaFOp, acc is already (M, N) +@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{1}, ::Val{2}) + a2d = reshape(a, (size(a, 1), 1)) + _muladd(a2d, b, acc, Val(2), Val(2)) +end + +# Mat-vec (2D × 1D): reshape b (K,) → (K, 1), acc (M,) → (M, 1), MmaFOp, squeeze back +@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{1}) + M, K = size(a, 1), size(b, 1) + b2d = reshape(b, (K, 1)) + acc2d = reshape(acc, (M, 1)) + result = _muladd(a, b2d, acc2d, Val(2), Val(2)) + reshape(result, (M,)) +end + +# Vec-vec (1D × 1D): not supported +@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}) + return :(throw(ArgumentError("Vector-vector multiply-accumulate is not supported."))) +end + +# Batched mat-vec / vec-mat (≥3D × 1D or 1D × ≥3D): not supported, unsqueeze manually +@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB} + NB >= 3 || return :(throw(ArgumentError("unreachable"))) + return :(throw(ArgumentError("Batched vec-mat is not supported. Reshape the 1D operand to 2D first."))) +end +@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA} + NA >= 3 || return :(throw(ArgumentError("unreachable"))) + return :(throw(ArgumentError("Batched mat-vec is not supported. Reshape the 1D operand to 2D first."))) +end + +# Batched matmul (≥3D × ≥3D): trailing batch dims with broadcast +# Julia convention: first two dims are matrix (M,K)/(K,N), trailing dims are batch. +# MmaFOp expects exactly 3D tiles (B, M, K), so we: +# 1. Broadcast batch dims to a common shape +# 2. Permute trailing batch → leading +# 3. Flatten multiple batch dims into one for MmaFOp +# 4. Unflatten + permute back after +@generated function _muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}, + ::Val{NA}, ::Val{NB}) where {T1, T2, T3, SA, SB, SC, NA, NB} + sa = Tuple(SA.parameters) + sb = Tuple(SB.parameters) + + # Matrix dims are first two; batch dims are trailing + M = sa[1]; K = sa[2]; N = sb[2] + a_batch = sa[3:end] + b_batch = sb[3:end] + + # Broadcast batch dims (pad shorter with trailing 1s, then broadcast) + n_batch = max(length(a_batch), length(b_batch)) + a_batch_padded = (a_batch..., ntuple(Returns(1), n_batch - length(a_batch))...) + b_batch_padded = (b_batch..., ntuple(Returns(1), n_batch - length(b_batch))...) + batch_shape = map(max, a_batch_padded, b_batch_padded) + B_flat = prod(batch_shape) + + quote + # Reshape + broadcast to align batch dims (still trailing) + a_bc = broadcast_to(reshape(a, $((M, K, a_batch_padded...))), $((M, K, batch_shape...))) + b_bc = broadcast_to(reshape(b, $((K, N, b_batch_padded...))), $((K, N, batch_shape...))) + acc_bc = broadcast_to(acc, $((M, N, batch_shape...))) + # Flatten batch dims to one (still trailing), then permute to leading + a_3d = permutedims(reshape(a_bc, $((M, K, B_flat))), (3, 1, 2)) + b_3d = permutedims(reshape(b_bc, $((K, N, B_flat))), (3, 1, 2)) + acc_3d = permutedims(reshape(acc_bc, $((M, N, B_flat))), (3, 1, 2)) + # MmaFOp + result_3d = Intrinsics.mma(a_3d, b_3d, acc_3d) + # Permute back to trailing, unflatten batch dims + reshape(permutedims(result_3d, (2, 3, 1)), $((M, N, batch_shape...))) + end +end -# Matrix multiplication (A * B like Julia arrays) +# Matrix multiplication: A * B = muladd(A, B, zeros) # Note: SA, SB type parameters required to avoid ambiguity with scalar*tile methods during codegen @inline function Base.:(*)(a::Tile{T1, SA}, b::Tile{T2, SB}) where {T1, T2, SA, SB} - _matmul(a, b, Val(ndims(a))) + _matmul(a, b, Val(ndims(a)), Val(ndims(b))) end -# 2D matmul: (M, K) × (K, N) → (M, N) -@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{2}) where {T1} - M = size(a, 1) - N = size(b, 2) - acc = zeros(T1, (M, N)) +# 2D × 2D → (M, N) +@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{2}, ::Val{2}) where {T1} + acc = zeros(T1, (size(a, 1), size(b, 2))) muladd(a, b, acc) end -# 3D batched matmul: (B, M, K) × (B, K, N) → (B, M, N) -@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{3}) where {T1} - B = max(size(a, 1), size(b, 1)) # Broadcast batch dimension - M = size(a, 2) - N = size(b, 3) - acc = zeros(T1, (B, M, N)) +# Vec-mat (1D × 2D) → (M, N) +@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{1}, ::Val{2}) where {T1} + acc = zeros(T1, (size(a, 1), size(b, 2))) muladd(a, b, acc) end +# Mat-vec (2D × 1D) → (M,) +@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{2}, ::Val{1}) where {T1} + acc = zeros(T1, (size(a, 1),)) + muladd(a, b, acc) +end + +# Vec-vec (1D × 1D): not supported +@generated function _matmul(::Tile, ::Tile, ::Val{1}, ::Val{1}) + return :(throw(ArgumentError("Vector-vector multiplication is not supported. Use dot(a, b) for inner products, or reshape explicitly."))) +end + +# Batched (≥3D × ≥3D) → (M, N, batch...) +@generated function _matmul(a::Tile{T1, SA}, b::Tile{T2, SB}, + ::Val{NA}, ::Val{NB}) where {T1, T2, SA, SB, NA, NB} + sa = Tuple(SA.parameters) + sb = Tuple(SB.parameters) + a_batch = sa[3:end] + b_batch = sb[3:end] + n_batch = max(length(a_batch), length(b_batch)) + a_batch_padded = (ntuple(_ -> 1, n_batch - length(a_batch))..., a_batch...) + b_batch_padded = (ntuple(_ -> 1, n_batch - length(b_batch))..., b_batch...) + batch_shape = map(max, a_batch_padded, b_batch_padded) + M = sa[1]; N = sb[2] + out_shape = (M, N, batch_shape...) + quote + acc = zeros(T1, $out_shape) + muladd(a, b, acc) + end +end + +# Batched × 1D: not supported — unsqueeze the 1D operand manually +@generated function _matmul(::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA} + NA >= 3 || return :(throw(ArgumentError("unreachable"))) + return :(throw(ArgumentError("Batched mat-vec is not supported. Reshape the 1D operand to 2D first."))) +end +@generated function _matmul(::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB} + NB >= 3 || return :(throw(ArgumentError("unreachable"))) + return :(throw(ArgumentError("Batched vec-mat is not supported. Reshape the 1D operand to 2D first."))) +end + #============================================================================= Selection =============================================================================# diff --git a/test/codegen/operations.jl b/test/codegen/operations.jl index 211d4d0c..0a39f8b3 100644 --- a/test/codegen/operations.jl +++ b/test/codegen/operations.jl @@ -2,6 +2,7 @@ spec1d = ct.ArraySpec{1}(16, true) spec2d = ct.ArraySpec{2}(16, true) spec3d = ct.ArraySpec{3}(16, true) +spec4d = ct.ArraySpec{4}(16, true) #========================================================================= 8.3 Core @@ -492,18 +493,81 @@ spec3d = ct.ArraySpec{3}(16, true) end end - @testset "matmul" begin + @testset "vec-mat outer product" begin @test @filecheck begin @check_label "entry" - code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,2,spec2d}}) do a, b, c + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,2,spec2d}}) do a, b, c + bidx = ct.bid(1) + tile_a = ct.load(a, bidx, (16,)) + tile_b = ct.load(b, bidx, (1, 16)) + # vec-mat: reshape + mma + @check "reshape" + @check "mma" + result = tile_a * tile_b + ct.store(c, bidx, result) + return + end + end + end + + @testset "mat-vec" begin + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b, c bidx = ct.bid(1) - bidy = ct.bid(2) tile_a = ct.load(a, bidx, (32, 16)) - tile_b = ct.load(b, bidy, (16, 32)) - # matmul via * operator = mma with zero accumulator + tile_b = ct.load(b, bidx, (16,)) + # mat-vec: reshape + mma + reshape + @check "reshape" @check "mma" result = tile_a * tile_b - ct.store(c, (bidx, bidy), result) + ct.store(c, bidx, result) + return + end + end + end + + @testset "batched matmul with trailing batch dims" begin + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,3,spec3d}, ct.TileArray{Float32,3,spec3d}, ct.TileArray{Float32,3,spec3d}}) do a, b, c + bidx = ct.bid(1) + tile_a = ct.load(a, bidx, (32, 16, 1)) + tile_b = ct.load(b, bidx, (16, 32, 4)) + # batched: broadcast + permute + mma + permute + @check "broadcast" + @check "permute" + @check "mma" + result = tile_a * tile_b + ct.store(c, bidx, result) + return + end + end + end + + @testset "vec-vec throws error" begin + @test_throws cuTile.IRError begin + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b + bidx = ct.bid(1) + tile_a = ct.load(a, bidx, (16,)) + tile_b = ct.load(b, bidx, (16,)) + tile_a * tile_b + return + end + end + end + + @testset "4D batched matmul (2 batch dims)" begin + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,4,spec4d}, ct.TileArray{Float32,4,spec4d}, ct.TileArray{Float32,4,spec4d}}) do a, b, c + bidx = ct.bid(1) + tile_a = ct.load(a, bidx, (16, 8, 2, 4)) + tile_b = ct.load(b, bidx, (8, 16, 2, 4)) + @check "permute" + @check "mma" + result = tile_a * tile_b + ct.store(c, bidx, result) return end end diff --git a/test/device/atomics.jl b/test/device/atomics.jl index 34df6c12..9feb7544 100644 --- a/test/device/atomics.jl +++ b/test/device/atomics.jl @@ -319,9 +319,9 @@ end @testset "atomic_add tile-indexed 3D" begin function atomic_add_3d_kernel(arr::ct.TileArray{Int,3}) # 3D index tiles — each is length 4, will broadcast to (4,4,4) = 64 elements - i = ct.reshape(ct.arange(4; dtype=Int), (4, 1, 1)) - j = ct.reshape(ct.arange(4; dtype=Int), (1, 4, 1)) - k = ct.reshape(ct.arange(4; dtype=Int), (1, 1, 4)) + i = reshape(ct.arange(4; dtype=Int), (4, 1, 1)) + j = reshape(ct.arange(4; dtype=Int), (1, 4, 1)) + k = reshape(ct.arange(4; dtype=Int), (1, 1, 4)) ct.atomic_add(arr, (i, j, k), 1; memory_order=ct.MemoryOrder.AcqRel) return diff --git a/test/device/tile.jl b/test/device/tile.jl index fd9ca54c..e03db211 100644 --- a/test/device/tile.jl +++ b/test/device/tile.jl @@ -867,3 +867,302 @@ end end end end + +@testset "matmul" begin + @testset "basic matmul" begin + function matmul_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,2}, + c::ct.TileArray{Float32,2}) + bidx = ct.bid(1) + bidy = ct.bid(2) + # Load tiles: a is (M, K), b is (K, N) + tile_a = ct.load(a, (bidx, 1), (32, 16)) + tile_b = ct.load(b, (1, bidy), (16, 32)) + # matmul: c = a @ b (using * operator) + result = tile_a * tile_b + ct.store(c, (bidx, bidy), result) + return + end + + M, K, N = 64, 16, 64 + a = CUDA.rand(Float32, M, K) + b = CUDA.rand(Float32, K, N) + c = CUDA.zeros(Float32, M, N) + + grid_x = cld(M, 32) + grid_y = cld(N, 32) + ct.launch(matmul_kernel, (grid_x, grid_y, 1), a, b, c) + + # Verify against CPU reference + a_cpu = Array(a) + b_cpu = Array(b) + c_cpu = Array(c) + c_ref = a_cpu * b_cpu + + @test c_cpu ≈ c_ref + end + + @testset "vec-mat outer product" begin + function outer_product_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,2}, + c::ct.TileArray{Float32,2}) + bidx = ct.bid(1) + tile_a = ct.load(a, bidx, (16,)) + tile_b = ct.load(b, bidx, (1, 16)) + result = tile_a * tile_b + ct.store(c, bidx, result) + return + end + + M, N = 16, 16 + a = CUDA.rand(Float32, M) + b = CUDA.rand(Float32, 1, N) + c = CUDA.zeros(Float32, M, N) + + ct.launch(outer_product_kernel, 1, a, b, c) + + a_cpu = Array(a) + b_cpu = Array(b) + c_cpu = Array(c) + # Julia: (M,) * (1, N) = (M, 1) * (1, N) = (M, N) + c_ref = reshape(a_cpu, :, 1) * b_cpu + + @test c_cpu ≈ c_ref + end + + @testset "mat-vec" begin + function matvec_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,1}, + c::ct.TileArray{Float32,1}) + bidx = ct.bid(1) + tile_a = ct.load(a, bidx, (32, 16)) + tile_b = ct.load(b, bidx, (16,)) + result = tile_a * tile_b + ct.store(c, bidx, result) + return + end + + M, K = 32, 16 + a = CUDA.rand(Float32, M, K) + b = CUDA.rand(Float32, K) + c = CUDA.zeros(Float32, M) + + ct.launch(matvec_kernel, 1, a, b, c) + + a_cpu = Array(a) + b_cpu = Array(b) + c_cpu = Array(c) + c_ref = a_cpu * b_cpu + + @test c_cpu ≈ c_ref + end + + @testset "batched matmul with trailing batch broadcast" begin + function batched_matmul_kernel(a::ct.TileArray{Float32,3}, b::ct.TileArray{Float32,3}, + c::ct.TileArray{Float32,3}) + bidx = ct.bid(1) + tile_a = ct.load(a, bidx, (32, 16, 1)) + tile_b = ct.load(b, bidx, (16, 32, 4)) + result = tile_a * tile_b + ct.store(c, bidx, result) + return + end + + M, K, N, B = 32, 16, 32, 4 + a = CUDA.rand(Float32, M, K, 1) + b = CUDA.rand(Float32, K, N, B) + c = CUDA.zeros(Float32, M, N, B) + + ct.launch(batched_matmul_kernel, 1, a, b, c) + + a_cpu = Array(a) + b_cpu = Array(b) + c_cpu = Array(c) + # Reference: broadcast batch dim and matmul per batch + a_bc = repeat(a_cpu, 1, 1, B) + c_ref = similar(c_cpu) + for i in 1:B + c_ref[:, :, i] = a_bc[:, :, i] * b_cpu[:, :, i] + end + + @test c_cpu ≈ c_ref + end + + @testset "multi-block vec-mat" begin + function multi_vecmat_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,2}, + c::ct.TileArray{Float32,2}) + bidx = ct.bid(1) + bidy = ct.bid(2) + tile_a = ct.load(a, bidx, (16,)) + tile_b = ct.load(b, (1, bidy), (1, 32)) + result = tile_a * tile_b + ct.store(c, (bidx, bidy), result) + return + end + + M, N = 64, 128 + a = CUDA.rand(Float32, M) + b = CUDA.rand(Float32, 1, N) + c = CUDA.zeros(Float32, M, N) + + # Each block handles a 16-row chunk of a and a 32-col chunk of b + ct.launch(multi_vecmat_kernel, (cld(M, 16), cld(N, 32)), a, b, c) + + a_cpu = Array(a) + b_cpu = Array(b) + c_cpu = Array(c) + c_ref = reshape(a_cpu, :, 1) * b_cpu + + @test c_cpu ≈ c_ref + end + + @testset "multi-block mat-vec" begin + function multi_matvec_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,1}, + c::ct.TileArray{Float32,1}) + bidx = ct.bid(1) + tile_a = ct.load(a, bidx, (32, 16)) + tile_b = ct.load(b, 1, (16,)) + result = tile_a * tile_b + ct.store(c, bidx, result) + return + end + + M, K = 128, 16 + a = CUDA.rand(Float32, M, K) + b = CUDA.rand(Float32, K) + c = CUDA.zeros(Float32, M) + + ct.launch(multi_matvec_kernel, cld(M, 32), a, b, c) + + a_cpu = Array(a) + b_cpu = Array(b) + c_cpu = Array(c) + c_ref = a_cpu * b_cpu + + @test c_cpu ≈ c_ref + end + + @testset "batched mat-vec (3D × 1D) errors" begin + @test_throws cuTile.IRError begin + ct.code_tiled(Tuple{ct.TileArray{Float32,3,ct.ArraySpec{3}(16,true)}, + ct.TileArray{Float32,1,ct.ArraySpec{1}(16,true)}}) do a, b + bidx = ct.bid(1) + tile_a = ct.load(a, bidx, (16, 8, 4)) + tile_b = ct.load(b, bidx, (8,)) + tile_a * tile_b + return + end + end + end + + @testset "4D batched matmul (2 batch dims)" begin + function batched_4d_kernel(a::ct.TileArray{Float32,4}, b::ct.TileArray{Float32,4}, + c::ct.TileArray{Float32,4}) + bidx = ct.bid(1) + tile_a = ct.load(a, bidx, (16, 8, 2, 4)) + tile_b = ct.load(b, bidx, (8, 16, 2, 4)) + result = tile_a * tile_b + ct.store(c, bidx, result) + return + end + + M, K, N, B1, B2 = 16, 8, 16, 2, 4 + a = CUDA.rand(Float32, M, K, B1, B2) + b = CUDA.rand(Float32, K, N, B1, B2) + c = CUDA.zeros(Float32, M, N, B1, B2) + + ct.launch(batched_4d_kernel, 1, a, b, c) + + a_cpu = Array(a) + b_cpu = Array(b) + c_cpu = Array(c) + c_ref = similar(c_cpu) + for j in 1:B2, i in 1:B1 + c_ref[:, :, i, j] = a_cpu[:, :, i, j] * b_cpu[:, :, i, j] + end + + @test c_cpu ≈ c_ref + end + + @testset "mixed dtype matmul (Float16 × Float32 with explicit convert)" begin + function mixed_matmul_kernel(a::ct.TileArray{Float16,2}, b::ct.TileArray{Float32,2}, + c::ct.TileArray{Float32,2}) + bidx = ct.bid(1) + tile_a = ct.load(a, bidx, (32, 16)) + tile_b = ct.load(b, bidx, (16, 32)) + # MmaFOp requires matching input types; caller must convert + result = convert(ct.Tile{Float32}, tile_a) * tile_b + ct.store(c, bidx, result) + return + end + + M, K, N = 32, 16, 32 + a = CUDA.rand(Float16, M, K) + b = CUDA.rand(Float32, K, N) + c = CUDA.zeros(Float32, M, N) + + ct.launch(mixed_matmul_kernel, 1, a, b, c) + + a_cpu = Array(a) + b_cpu = Array(b) + c_cpu = Array(c) + c_ref = Float32.(a_cpu) * b_cpu + + @test c_cpu ≈ c_ref rtol=1e-2 + end + + @testset "muladd mat-vec accumulated" begin + function matvec_muladd_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,1}, + bias::ct.TileArray{Float32,1}, c::ct.TileArray{Float32,1}) + bidx = ct.bid(1) + tile_a = ct.load(a, bidx, (32, 16)) + tile_b = ct.load(b, 1, (16,)) + tile_bias = ct.load(bias, bidx, (32,)) + result = muladd(tile_a, tile_b, tile_bias) + ct.store(c, bidx, result) + return + end + + M, K = 64, 16 + a = CUDA.rand(Float32, M, K) + b = CUDA.rand(Float32, K) + bias = CUDA.rand(Float32, M) + c = CUDA.zeros(Float32, M) + + ct.launch(matvec_muladd_kernel, cld(M, 32), a, b, bias, c) + + a_cpu = Array(a) + b_cpu = Array(b) + bias_cpu = Array(bias) + c_cpu = Array(c) + c_ref = a_cpu * b_cpu + bias_cpu + + @test c_cpu ≈ c_ref + end + + @testset "muladd matmul accumulated" begin + function matmul_muladd_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,2}, + bias::ct.TileArray{Float32,2}, c::ct.TileArray{Float32,2}) + bidx = ct.bid(1) + tile_a = ct.load(a, bidx, (32, 16)) + tile_b = ct.load(b, bidx, (16, 32)) + tile_bias = ct.load(bias, bidx, (32, 32)) + result = muladd(tile_a, tile_b, tile_bias) + ct.store(c, bidx, result) + return + end + + M, K, N = 32, 16, 32 + a = CUDA.rand(Float32, M, K) + b = CUDA.rand(Float32, K, N) + bias = CUDA.rand(Float32, M, N) + c = CUDA.zeros(Float32, M, N) + + ct.launch(matmul_muladd_kernel, 1, a, b, bias, c) + + a_cpu = Array(a) + b_cpu = Array(b) + bias_cpu = Array(bias) + c_cpu = Array(c) + c_ref = a_cpu * b_cpu + bias_cpu + + @test c_cpu ≈ c_ref + end +end From a0049f97625f30aaf8524f2257e2fb4dd0756262 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 25 Mar 2026 09:01:20 +0100 Subject: [PATCH 2/3] Update FFT example. --- examples/fft.jl | 89 ++++++++++++++++++++++++------------------------- 1 file changed, 44 insertions(+), 45 deletions(-) diff --git a/examples/fft.jl b/examples/fft.jl index eadd3be0..c4159d28 100644 --- a/examples/fft.jl +++ b/examples/fft.jl @@ -20,7 +20,8 @@ using FFTW # in columns. In Julia column-major, reshape (F1F2, F0) puts stride-F0 elements in rows. # We use right-multiply X @ W instead of W @ X to process rows instead of columns. # -# Input/output layout: (D, BS, N2D) where D=2 for real/imag interleaving. +# Input/output memory layout: (D, BS, N2D) where D=2 for real/imag interleaving. +# Internally, BS is permuted to trailing position for batched matmul convention. function fft_kernel( x_packed_in::ct.TileArray{Float32, 3}, # Input (D, BS, N2D) - natural Julia complex layout y_packed_out::ct.TileArray{Float32, 3}, # Output (D, BS, N2D) @@ -55,96 +56,94 @@ function fft_kernel( bid = ct.bid(1) # --- Load Input Data --- - # Input is (D, BS, N2D) where D=2 for real/imag. Load and reshape to (2, BS, N). - X_ri = reshape(ct.load(x_packed_in; index=(1, bid, 1), shape=(D, BS, N2D)), (2, BS, N)) + # Input is (D, BS, N2D) where D=2 for real/imag. Load and permute BS to trailing. + X_ri_mem = reshape(ct.load(x_packed_in; index=(1, bid, 1), shape=(D, BS, N2D)), (2, BS, N)) + X_ri = permutedims(X_ri_mem, (1, 3, 2)) # (2, N, BS) — trailing batch # Split real and imaginary parts (extract from first dimension) - X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, BS, N)), (BS, F1F2, F0)) - X_i = reshape(ct.extract(X_ri, (2, 1, 1), (1, BS, N)), (BS, F1F2, F0)) + X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, N, BS)), (F1F2, F0, BS)) + X_i = reshape(ct.extract(X_ri, (2, 1, 1), (1, N, BS)), (F1F2, F0, BS)) # --- Load DFT Matrices --- - # W0 (F0 x F0) - for right-multiply X @ W0 + # W0 (F0 x F0) - for right-multiply X @ W0, batch dim trailing W0_ri = reshape(ct.load(W0; index=(1, 1, 1), shape=(F0, F0, 2)), (F0, F0, 2)) - W0_r = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 1), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0)) - W0_i = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 2), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0)) + W0_r = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 1), (F0, F0, 1)), (F0, F0, 1)), (F0, F0, BS)) + W0_i = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 2), (F0, F0, 1)), (F0, F0, 1)), (F0, F0, BS)) # W1 (F1 x F1) W1_ri = reshape(ct.load(W1; index=(1, 1, 1), shape=(F1, F1, 2)), (F1, F1, 2)) - W1_r = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 1), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1)) - W1_i = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 2), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1)) + W1_r = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 1), (F1, F1, 1)), (F1, F1, 1)), (F1, F1, BS)) + W1_i = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 2), (F1, F1, 1)), (F1, F1, 1)), (F1, F1, BS)) # W2 (F2 x F2) W2_ri = reshape(ct.load(W2; index=(1, 1, 1), shape=(F2, F2, 2)), (F2, F2, 2)) - W2_r = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 1), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2)) - W2_i = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 2), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2)) + W2_r = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 1), (F2, F2, 1)), (F2, F2, 1)), (F2, F2, BS)) + W2_i = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 2), (F2, F2, 1)), (F2, F2, 1)), (F2, F2, BS)) # --- Load Twiddle Factors --- # T0 (F1F2, F0) - note swapped from Python's (F0, F1F2) T0_ri = reshape(ct.load(T0; index=(1, 1, 1), shape=(F1F2, F0, 2)), (F1F2, F0, 2)) - T0_r = reshape(ct.extract(T0_ri, (1, 1, 1), (F1F2, F0, 1)), (1, N)) - T0_i = reshape(ct.extract(T0_ri, (1, 1, 2), (F1F2, F0, 1)), (1, N)) + T0_r = reshape(ct.extract(T0_ri, (1, 1, 1), (F1F2, F0, 1)), (N, 1)) + T0_i = reshape(ct.extract(T0_ri, (1, 1, 2), (F1F2, F0, 1)), (N, 1)) # T1 (F0F2, F1) - note swapped from Python's (F1, F2) T1_ri = reshape(ct.load(T1; index=(1, 1, 1), shape=(F0F2, F1, 2)), (F0F2, F1, 2)) - T1_r = reshape(ct.extract(T1_ri, (1, 1, 1), (F0F2, F1, 1)), (1, F0F2 * F1)) - T1_i = reshape(ct.extract(T1_ri, (1, 1, 2), (F0F2, F1, 1)), (1, F0F2 * F1)) + T1_r = reshape(ct.extract(T1_ri, (1, 1, 1), (F0F2, F1, 1)), (F0F2 * F1, 1)) + T1_i = reshape(ct.extract(T1_ri, (1, 1, 2), (F0F2, F1, 1)), (F0F2 * F1, 1)) # --- Stage 0: F0-point DFT --- - # X is (BS, F1F2, F0), W0 is (BS, F0, F0) + # X is (F1F2, F0, BS), W0 is (F0, F0, BS) — trailing batch # Right-multiply: X @ W0 processes each row (F1F2 rows, each with F0 elements) - # Each row has elements at stride F1F2 in the original array - exactly what we need! - X_r_ = X_r * W0_r - X_i * W0_i # (BS, F1F2, F0) @ (BS, F0, F0) → (BS, F1F2, F0) + X_r_ = X_r * W0_r - X_i * W0_i # (F1F2, F0, BS) @ (F0, F0, BS) → (F1F2, F0, BS) X_i_ = X_r * W0_i + X_i * W0_r # --- Twiddle & Permute 0 --- - # Reshape to (BS, N) for element-wise twiddle multiply - X_r_flat = reshape(X_r_, (BS, N)) - X_i_flat = reshape(X_i_, (BS, N)) + # Reshape to (N, BS) for element-wise twiddle multiply + X_r_flat = reshape(X_r_, (N, BS)) + X_i_flat = reshape(X_i_, (N, BS)) X_r2 = T0_r .* X_r_flat .- T0_i .* X_i_flat X_i2 = T0_i .* X_r_flat .+ T0_r .* X_i_flat # Reshape and permute for stage 1 - # Current logical layout after reshape (BS, F1F2, F0): data at (bs, f1*F2+f2, f0) - # Reshape to (BS, F2, F1, F0) then permute to (BS, F0F2, F1) for stage 1 - X_r3 = reshape(X_r2, (BS, F2, F1, F0)) - X_i3 = reshape(X_i2, (BS, F2, F1, F0)) - X_r4 = permutedims(X_r3, (1, 2, 4, 3)) # (BS, F2, F0, F1) - X_i4 = permutedims(X_i3, (1, 2, 4, 3)) - X_r5 = reshape(X_r4, (BS, F0F2, F1)) - X_i5 = reshape(X_i4, (BS, F0F2, F1)) + # Reshape to (F2, F1, F0, BS) then permute to (F0F2, F1, BS) for stage 1 + X_r3 = reshape(X_r2, (F2, F1, F0, BS)) + X_i3 = reshape(X_i2, (F2, F1, F0, BS)) + X_r4 = permutedims(X_r3, (1, 3, 2, 4)) # (F2, F0, F1, BS) + X_i4 = permutedims(X_i3, (1, 3, 2, 4)) + X_r5 = reshape(X_r4, (F0F2, F1, BS)) + X_i5 = reshape(X_i4, (F0F2, F1, BS)) # --- Stage 1: F1-point DFT --- - # X is (BS, F0F2, F1), W1 is (BS, F1, F1) + # X is (F0F2, F1, BS), W1 is (F1, F1, BS) X_r6 = X_r5 * W1_r - X_i5 * W1_i X_i6 = X_r5 * W1_i + X_i5 * W1_r # --- Twiddle & Permute 1 --- - X_r_flat2 = reshape(X_r6, (BS, N)) - X_i_flat2 = reshape(X_i6, (BS, N)) + X_r_flat2 = reshape(X_r6, (N, BS)) + X_i_flat2 = reshape(X_i6, (N, BS)) X_r7 = T1_r .* X_r_flat2 .- T1_i .* X_i_flat2 X_i7 = T1_i .* X_r_flat2 .+ T1_r .* X_i_flat2 # Reshape and permute for stage 2 - X_r8 = reshape(X_r7, (BS, F2, F0, F1)) - X_i8 = reshape(X_i7, (BS, F2, F0, F1)) - X_r9 = permutedims(X_r8, (1, 3, 4, 2)) # (BS, F0, F1, F2) - X_i9 = permutedims(X_i8, (1, 3, 4, 2)) - X_r10 = reshape(X_r9, (BS, F0F1, F2)) - X_i10 = reshape(X_i9, (BS, F0F1, F2)) + X_r8 = reshape(X_r7, (F2, F0, F1, BS)) + X_i8 = reshape(X_i7, (F2, F0, F1, BS)) + X_r9 = permutedims(X_r8, (2, 3, 1, 4)) # (F0, F1, F2, BS) + X_i9 = permutedims(X_i8, (2, 3, 1, 4)) + X_r10 = reshape(X_r9, (F0F1, F2, BS)) + X_i10 = reshape(X_i9, (F0F1, F2, BS)) # --- Stage 2: F2-point DFT --- - # X is (BS, F0F1, F2), W2 is (BS, F2, F2) + # X is (F0F1, F2, BS), W2 is (F2, F2, BS) X_r11 = X_r10 * W2_r - X_i10 * W2_i X_i11 = X_r10 * W2_i + X_i10 * W2_r # --- Final Output --- - # After stage 2, data is in (BS, F0F1, F2) layout - # Reshape to (BS, F0, F1, F2) - output is already in frequency order - X_r_final = reshape(X_r11, (1, BS, N)) - X_i_final = reshape(X_i11, (1, BS, N)) + X_r_final = reshape(X_r11, (1, N, BS)) + X_i_final = reshape(X_i11, (1, N, BS)) # --- Concatenate and Store --- - Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, BS, N2D)) + # Permute BS back to middle for memory layout (D, BS, N2D) + Y_ri = permutedims(reshape(ct.cat((X_r_final, X_i_final), 1), (D, N2D, BS)), (1, 3, 2)) ct.store(y_packed_out; index=(1, bid, 1), tile=Y_ri) return From d4fc79465b6020f7fa099c32d75db953d5937af0 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 25 Mar 2026 09:09:21 +0100 Subject: [PATCH 3/3] Fix batch dim padding mismatch between _matmul and _muladd When operands have different numbers of batch dimensions (e.g. (M, K, 4) * (K, N, 2, 4)), _matmul pads the shorter batch tuple with ones to align them before computing the output shape and creating the zero accumulator. _muladd does the same padding to reshape operands before broadcasting. These two functions disagreed on *where* to pad: _matmul inserted leading ones ((1, 4) for a 1-batch operand against a 2-batch one) while _muladd appended trailing ones ((4, 1)). This meant the acc shape from _matmul wouldn't match what _muladd expected, causing a reshape element-count mismatch at the Tile IR level. Fix _matmul to use trailing ones, consistent with _muladd. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/language/operations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/language/operations.jl b/src/language/operations.jl index 3f4b51f8..0d83d14f 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -1040,8 +1040,8 @@ end a_batch = sa[3:end] b_batch = sb[3:end] n_batch = max(length(a_batch), length(b_batch)) - a_batch_padded = (ntuple(_ -> 1, n_batch - length(a_batch))..., a_batch...) - b_batch_padded = (ntuple(_ -> 1, n_batch - length(b_batch))..., b_batch...) + a_batch_padded = (a_batch..., ntuple(_ -> 1, n_batch - length(a_batch))...) + b_batch_padded = (b_batch..., ntuple(_ -> 1, n_batch - length(b_batch))...) batch_shape = map(max, a_batch_padded, b_batch_padded) M = sa[1]; N = sb[2] out_shape = (M, N, batch_shape...)