Skip to content

Commit 6c9976c

Browse files
committed
[LayerNorm] Tune cluster for fp32 fwd
1 parent b32988f commit 6c9976c

1 file changed

Lines changed: 7 additions & 0 deletions

File tree

quack/rmsnorm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,14 @@ def _set_cluster_n(self):
6464
thresholds = [(8 * 1024, 1), (16 * 1024, 2), (32 * 1024, 4), (64 * 1024, 8)]
6565
elif const_expr(self.dtype.width == 16):
6666
thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
67+
elif self.is_layernorm:
68+
# fp32 layernorm: bump cluster earlier than fp16/bf16. The 2-pass path's
69+
# single-CTA tile is bandwidth-limited at N=16k/32k; cluster_n=2 splits
70+
# the row across two CTAs and recovers ~3-14% at those sizes.
71+
thresholds = [(8 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)]
6772
else:
73+
# fp32 rmsnorm (1-pass) is already saturated at cluster_n=1 for N<=32k;
74+
# bumping to cluster_n=2 there regresses ~3%.
6875
thresholds = [(32 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)]
6976
for limit, cluster in thresholds:
7077
if N <= limit:

0 commit comments

Comments
 (0)