Skip to content

Commit a873277

Browse files
maleadtclaude
andauthored
Add pow2 strength reduction (#177)
Replace expensive transcendental pow(x, 2.0) with a simple multiply. This directly benefits layernorm's variance computation where `centered_tx .^ 2.0f0` was generating a tile-wide pow instruction. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 390fdee commit a873277

5 files changed

Lines changed: 74 additions & 21 deletions

File tree

README.md

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,19 +97,23 @@ Benchmarks comparing cuTile.jl against cuTile Python on an RTX 5080 (`tileiras`
9797

9898
| Kernel | Size | Julia | Python | Status |
9999
|--------|------|-------|--------|--------|
100-
| Vector Addition | 2^27 f32 | 842 GB/s | 844 GB/s | OK (=) |
101-
| Matrix Transpose | 8192² f32 | 801 GB/s | 810 GB/s | OK (-1%) |
102-
| Layer Normalization | 4096² f32 fwd | 687 GB/s | 720 GB/s | -5% |
103-
| Matrix Multiplication | 4096³ f32 | 47.2 TFLOPS | 43.3 TFLOPS | +9%* |
104-
| Batch Matrix Multiply | 1024×512×2048 ×8 f32 | 33.5 TFLOPS | 30.8 TFLOPS | +9%* |
105-
| FFT (3-stage Cooley-Tukey) | 1024-pt ×64 c64 | 3264 μs | 3133 μs | -4% |
106-
| Mixture of Experts | 256tok 1024h 32e 2048i f16 | 19.5 TFLOPS | 20.1 TFLOPS | -3% |
107-
| Attention (FMHA) | 8×16×1024² ×64 f16 causal | 87.8 TFLOPS | 61.3 TFLOPS | +43%** |
108-
109-
\* Likely because Julia's `for` loop guards give `tileiras` a guarantee that the
100+
| Vector Addition | 2^27 f32 | 841 GB/s | 845 GB/s | OK (=) |
101+
| Matrix Transpose | 8192² f32 | 805 GB/s | 811 GB/s | OK (-1%) |
102+
| Layer Norm fwd | 4096² f32 | 925 GB/s | 722 GB/s | +28%* |
103+
| Layer Norm bwd | 4096² f32 | 243 GB/s | 251 GB/s | -3% |
104+
| Matrix Multiplication | 4096³ f32 | 46.9 TFLOPS | 43.4 TFLOPS | +8%** |
105+
| Batch Matrix Multiply | 1024×512×2048 ×8 f32 | 33.6 TFLOPS | 30.9 TFLOPS | +9%** |
106+
| FFT (3-stage Cooley-Tukey) | 1024-pt ×64 c64 | 3263 μs | 3127 μs | -4% |
107+
| Mixture of Experts | 256tok 1024h 32e 2048i f16 | 19.3 TFLOPS | 20.3 TFLOPS | -5% |
108+
| Attention (FMHA) | 8×16×1024² ×64 f16 causal | 88.5 TFLOPS | 61.6 TFLOPS | +44%*** |
109+
110+
\* The pow(x, 2) → mulf(x, x) strength reduction eliminates the expensive
111+
transcendental in the variance computation. Python still emits `pow`.
112+
113+
\*\* Likely because Julia's `for` loop guards give `tileiras` a guarantee that the
110114
loop body executes at least once, enabling more aggressive warp scheduling.
111115

112-
\*\* Likely due to Python's compiler splitting the causal masking loop into two
116+
\*\*\* Likely due to Python's compiler splitting the causal masking loop into two
113117
loops, duplicating the loop body. Julia emits a single loop with a conditional.
114118

115119

examples/benchmarks.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,10 @@ function run_benchmark(name::String)
9191
data = @invokelatest mod.prepare(; benchmark=true)
9292

9393
# Get metric info if available
94-
metric_total, metric_unit = 0, ""
94+
# metric() returns either (total, unit) or Dict("impl" => (total, unit))
95+
metric_result = nothing
9596
if isdefined(mod, :metric)
96-
metric_total, metric_unit = @invokelatest mod.metric(data)
97+
metric_result = @invokelatest mod.metric(data)
9798
end
9899

99100
# Run cuTile
@@ -117,7 +118,7 @@ function run_benchmark(name::String)
117118
merge!(results, others)
118119
end
119120

120-
return results, metric_total, metric_unit
121+
return results, metric_result
121122
end
122123

123124
#=============================================================================
@@ -142,14 +143,23 @@ function main()
142143
continue
143144
end
144145

145-
results, metric_total, metric_unit = ret
146+
results, metric_result = ret
146147

147148
# Convert to BenchmarkResult for printing
148149
benchmark_results = BenchmarkResult[]
149150
for (impl_name, times) in results
150151
min_t = minimum(times)
151152
mean_t = sum(times) / length(times)
152-
tp = !isempty(metric_unit) ? format_throughput(metric_total, metric_unit, min_t) : ""
153+
tp = ""
154+
if metric_result isa Dict
155+
if haskey(metric_result, impl_name)
156+
mt, mu = metric_result[impl_name]
157+
tp = format_throughput(mt, mu, min_t)
158+
end
159+
elseif metric_result isa Tuple
160+
mt, mu = metric_result
161+
tp = format_throughput(mt, mu, min_t)
162+
end
153163
push!(benchmark_results, BenchmarkResult(impl_name, min_t, mean_t, tp))
154164
end
155165

examples/layernorm.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,13 @@ function test_layernorm(M, N, TILE_N; TILE_M::Int=32, eps::Float32=1f-5, name=no
369369
end
370370

371371
function metric(data)
372-
# Forward: 3 reads of X + W + B reads + Y write + Mean/Rstd writes ≈ 4*M*N floats
373-
return 4 * data.M * data.N * sizeof(Float32), "GB/s"
372+
MN = data.M * data.N * sizeof(Float32)
373+
return Dict(
374+
# Forward: X read (3 passes: mean, var, normalize) + Y write ≈ 4*M*N floats
375+
"cuTile Fwd" => (4 * MN, "GB/s"),
376+
# Backward: X read (2 passes) + DY read (2 passes) + DX write ≈ 5*M*N floats
377+
"cuTile Bwd" => (5 * MN, "GB/s"),
378+
)
374379
end
375380

376381
# No run_others for layernorm - no simple reference implementation to compare against

src/compiler/passes/pipeline.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,26 @@ const COMPARISON_RULES = RewriteRule[
189189
$(ComparisonPredicate.LessThan), $(Signedness.Signed)))
190190
]
191191

192+
#=============================================================================
193+
Power Strength Reduction (rewrite)
194+
=============================================================================#
195+
196+
# pow(x, 2) → mulf(x, x): replaces an expensive transcendental with a multiply.
197+
# The MLIR Tile IR backend has no canonicalization for pow, so this is purely
198+
# a Julia-level optimization. Applies to the variance computation in layernorm
199+
# (centered_tx .^ 2.0f0). Uses a guard with == so it matches any float type
200+
# (Float16, BFloat16, Float32, Float64, TFloat32). Integer-literal exponents
201+
# (x .^ 2) are already handled by Julia's literal_pow → x*x → mulf(x, x).
202+
203+
function is_pow_two(match, driver)
204+
c = const_value(driver.constants, match.bindings[:exp])
205+
c !== nothing && c == 2
206+
end
207+
208+
const POWER_RULES = RewriteRule[
209+
@rewrite(Intrinsics.pow(~x, ~exp) => Intrinsics.mulf(~x, ~x), is_pow_two)
210+
]
211+
192212
#=============================================================================
193213
Combined Rule Set
194214
=============================================================================#
@@ -198,6 +218,7 @@ const OPTIMIZATION_RULES = RewriteRule[
198218
ALGEBRA_RULES...,
199219
FMA_RULES...,
200220
COMPARISON_RULES...,
221+
POWER_RULES...,
201222
]
202223

203224
#=============================================================================

test/codegen/operations.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,18 +1390,31 @@ end
13901390
end
13911391
end
13921392

1393-
# scalar exponent
1393+
# scalar exponent (pow(x, 2) is strength-reduced to mulf(x, x))
13941394
@test @filecheck begin
13951395
@check_label "entry"
13961396
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a
13971397
pid = ct.bid(1)
13981398
tile = ct.load(a, pid, (16,))
1399-
@check "broadcast"
1400-
@check "pow"
1399+
@check "mulf"
1400+
@check_not "pow"
14011401
Base.donotdelete(tile .^ 2.0f0)
14021402
return
14031403
end
14041404
end
1405+
1406+
# pow2 strength reduction works for Float64 too
1407+
@test @filecheck begin
1408+
@check_label "entry"
1409+
code_tiled(Tuple{ct.TileArray{Float64,1,spec1d}}) do a
1410+
pid = ct.bid(1)
1411+
tile = ct.load(a, pid, (16,))
1412+
@check "mulf"
1413+
@check_not "pow"
1414+
Base.donotdelete(tile .^ 2.0)
1415+
return
1416+
end
1417+
end
14051418
end
14061419

14071420
@testset "scalar math functions" begin

0 commit comments

Comments
 (0)