Skip to content

Fold gather(load(t, [..., :, ...]), dim, idx) into direct indirect load#2684

Open
AmesingFlank wants to merge 1 commit into
mainfrom
AmesingFlank/stack/63
Open

Fold gather(load(t, [..., :, ...]), dim, idx) into direct indirect load#2684
AmesingFlank wants to merge 1 commit into
mainfrom
AmesingFlank/stack/63

Conversation

@AmesingFlank
Copy link
Copy Markdown
Contributor

@AmesingFlank AmesingFlank commented Jun 4, 2026

Stacked PRs:


Fold gather(load(t, [..., :, ...]), dim, idx) into direct indirect load

The cross_entropy pattern (logits[tile_n, :].gather(1, idx[tile_n].unsqueeze(1)))
was producing invalid Triton (NameError on the load) when the reduction roller
tried to roll the surrounding amax/sum: a _for_loop output can't carry the
rdim-shaped logits_rows out to feed the gather sitting outside the loop.

Rewrite gather(load(t, [..., :, ...]), dim, idx) at the FX layer to a direct
indirect load(t, [..., idx, ...]). The two forms compute the same values, but
the direct form skips the wide load entirely — so the rdim-shaped intermediate
never exists and the roller's existing logic handles the surrounding reductions
naturally. The CuTe backend already does this fold at codegen time
(aten_lowering.codegen_gather_cute); lifting it to FX surfaces the same
simplification to the Triton backend and the rolling analysis.

The fold is gated to the cross_entropy-style pattern: load's dim axis is a
full slice, gather index has a singleton at dim and the same rank as the
load's subscript, no extra_mask. Other gather shapes go through the existing
aten.gather path.

After this, examples/cross_entropy.py runs end-to-end: autotuning finds rolled
configs (block_sizes=[1], reduction_loops=[16384]) and the kernel is ~3x
faster than torch eager.

Co-Authored-By: Claude Opus 4.7 (1M context) noreply@anthropic.com

@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/63 branch from 1b36076 to b70396c Compare June 4, 2026 01:37
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 4, 2026
@AmesingFlank AmesingFlank marked this pull request as draft June 4, 2026 01:50
@AmesingFlank AmesingFlank marked this pull request as ready for review June 4, 2026 01:50
AmesingFlank added a commit that referenced this pull request Jun 4, 2026
The cross_entropy pattern (logits[tile_n, :].gather(1, idx[tile_n].unsqueeze(1)))
was producing invalid Triton (NameError on the load) when the reduction roller
tried to roll the surrounding amax/sum: a _for_loop output can't carry the
rdim-shaped logits_rows out to feed the gather sitting outside the loop.

Rewrite gather(load(t, [..., :, ...]), dim, idx) at the FX layer to a direct
indirect load(t, [..., idx, ...]). The two forms compute the same values, but
the direct form skips the wide load entirely — so the rdim-shaped intermediate
never exists and the roller's existing logic handles the surrounding reductions
naturally. The CuTe backend already does this fold at codegen time
(aten_lowering.codegen_gather_cute); lifting it to FX surfaces the same
simplification to the Triton backend and the rolling analysis.

The fold is gated to the cross_entropy-style pattern: load's dim axis is a
full slice, gather index has a singleton at dim and the same rank as the
load's subscript, no extra_mask. Other gather shapes go through the existing
aten.gather path.

After this, examples/cross_entropy.py runs end-to-end: autotuning finds rolled
configs (block_sizes=[1], reduction_loops=[16384]) and the kernel is ~3x
faster than torch eager.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

stack-info: PR: #2684, branch: AmesingFlank/stack/63
@AmesingFlank AmesingFlank marked this pull request as draft June 4, 2026 01:52
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/63 branch from b70396c to 674c822 Compare June 4, 2026 01:52
@AmesingFlank AmesingFlank marked this pull request as ready for review June 4, 2026 01:52
AmesingFlank added a commit that referenced this pull request Jun 4, 2026
The cross_entropy pattern (logits[tile_n, :].gather(1, idx[tile_n].unsqueeze(1)))
was producing invalid Triton (NameError on the load) when the reduction roller
tried to roll the surrounding amax/sum: a _for_loop output can't carry the
rdim-shaped logits_rows out to feed the gather sitting outside the loop.

Rewrite gather(load(t, [..., :, ...]), dim, idx) at the FX layer to a direct
indirect load(t, [..., idx, ...]). The two forms compute the same values, but
the direct form skips the wide load entirely — so the rdim-shaped intermediate
never exists and the roller's existing logic handles the surrounding reductions
naturally. The CuTe backend already does this fold at codegen time
(aten_lowering.codegen_gather_cute); lifting it to FX surfaces the same
simplification to the Triton backend and the rolling analysis.

The fold is gated to the cross_entropy-style pattern: load's dim axis is a
full slice, gather index has a singleton at dim and the same rank as the
load's subscript, no extra_mask. Other gather shapes go through the existing
aten.gather path.

After this, examples/cross_entropy.py runs end-to-end: autotuning finds rolled
configs (block_sizes=[1], reduction_loops=[16384]) and the kernel is ~3x
faster than torch eager.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

stack-info: PR: #2684, branch: AmesingFlank/stack/63
@AmesingFlank AmesingFlank marked this pull request as draft June 4, 2026 01:59
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/63 branch from 674c822 to 190a1a3 Compare June 4, 2026 01:59
@AmesingFlank AmesingFlank marked this pull request as ready for review June 4, 2026 01:59
@AmesingFlank AmesingFlank marked this pull request as draft June 4, 2026 02:06
@AmesingFlank AmesingFlank marked this pull request as ready for review June 4, 2026 02:07
@AmesingFlank AmesingFlank marked this pull request as draft June 4, 2026 02:12
@AmesingFlank AmesingFlank marked this pull request as ready for review June 4, 2026 02:12
@AmesingFlank AmesingFlank requested review from choijon5 and jansel June 4, 2026 02:52
The cross_entropy pattern (logits[tile_n, :].gather(1, idx[tile_n].unsqueeze(1)))
was producing invalid Triton (NameError on the load) when the reduction roller
tried to roll the surrounding amax/sum: a _for_loop output can't carry the
rdim-shaped logits_rows out to feed the gather sitting outside the loop.

Rewrite gather(load(t, [..., :, ...]), dim, idx) at the FX layer to a direct
indirect load(t, [..., idx, ...]). The two forms compute the same values, but
the direct form skips the wide load entirely — so the rdim-shaped intermediate
never exists and the roller's existing logic handles the surrounding reductions
naturally. The CuTe backend already does this fold at codegen time
(aten_lowering.codegen_gather_cute); lifting it to FX surfaces the same
simplification to the Triton backend and the rolling analysis.

The fold is gated to the cross_entropy-style pattern: load's dim axis is a
full slice, gather index has a singleton at dim and the same rank as the
load's subscript, no extra_mask. Other gather shapes go through the existing
aten.gather path.

After this, examples/cross_entropy.py runs end-to-end: autotuning finds rolled
configs (block_sizes=[1], reduction_loops=[16384]) and the kernel is ~3x
faster than torch eager.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

stack-info: PR: #2684, branch: AmesingFlank/stack/63
@AmesingFlank AmesingFlank marked this pull request as draft June 4, 2026 05:11
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/63 branch from 190a1a3 to 13ba6b2 Compare June 4, 2026 05:11
@AmesingFlank AmesingFlank marked this pull request as ready for review June 4, 2026 05:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants