Commit b70396c

Fold gather(load(t, [..., :, ...]), dim, idx) into direct indirect load
The cross_entropy pattern (logits[tile_n, :].gather(1, idx[tile_n].unsqueeze(1))) was producing invalid Triton (NameError on the load) when the reduction roller tried to roll the surrounding amax/sum: a _for_loop output can't carry the rdim-shaped logits_rows out to feed the gather sitting outside the loop.

Rewrite gather(load(t, [..., :, ...]), dim, idx) at the FX layer to a direct indirect load(t, [..., idx, ...]). The two forms compute the same values, but the direct form skips the wide load entirely, so the rdim-shaped intermediate never exists and the roller's existing logic handles the surrounding reductions naturally.

The CuTe backend already does this fold at codegen time (aten_lowering.codegen_gather_cute); lifting it to FX surfaces the same simplification to the Triton backend and the rolling analysis.

The fold is gated to the cross_entropy-style pattern: load's dim axis is a full slice, gather index has a singleton at dim and the same rank as the load's subscript, no extra_mask. Other gather shapes go through the existing aten.gather path.

After this, examples/cross_entropy.py runs end-to-end: autotuning finds rolled configs (block_sizes=[1], reduction_loops=[16384]) and the kernel is ~3x faster than torch eager.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
stack-info: PR: #2684, branch: AmesingFlank/stack/63
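The equivalence the commit message relies on can be sketched in plain Python, with nested lists standing in for tensors. This is an illustrative model only: `gather_dim1`, `wide_load`, and `direct_indexed_load` are hypothetical helper names, not Helion or PyTorch APIs, and the data is made up.

```python
# Pure-Python model of the two forms; nested lists stand in for tensors.
# All helper names are illustrative, not part of Helion or PyTorch.


def gather_dim1(rows, idx):
    """Model of rows.gather(1, idx): out[i][j] = rows[i][idx[i][j]]."""
    return [[row[j] for j in idx_row] for row, idx_row in zip(rows, idx)]


def wide_load(t, row_ids):
    """Model of load(t, [tile_n, :]): materializes every column of each row."""
    return [list(t[i]) for i in row_ids]


def direct_indexed_load(t, row_ids, idx):
    """Model of load(t, [tile_n, idx]) with idx of shape [len(row_ids), 1].

    Row ids and column indices are paired elementwise, so only one element
    per row is read and the wide intermediate is never built.
    """
    return [[t[i][idx_row[0]]] for i, idx_row in zip(row_ids, idx)]


logits = [
    [0.1, 0.7, 0.2, 0.0],
    [0.3, 0.1, 0.5, 0.1],
    [0.9, 0.0, 0.0, 0.1],
]
tile_n = [0, 1, 2]
idx = [[1], [2], [0]]  # shape [N, 1], like idx[tile_n].unsqueeze(1)

via_gather = gather_dim1(wide_load(logits, tile_n), idx)
via_direct = direct_indexed_load(logits, tile_n, idx)
```

Both paths yield one element per row; the direct form simply never touches the other columns, which is the property the fold exploits.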
1 parent a012080 commit b70396c

2 files changed

Lines changed: 135 additions & 0 deletions

helion/_compiler/device_ir.py

Lines changed: 100 additions & 0 deletions
@@ -2392,6 +2392,8 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:

     for graph in device_ir.graphs:
         rewrite_implicit_random_ops(graph.graph)
+    for graph in device_ir.graphs:
+        fold_gather_into_load(graph.graph)
     if CompileEnvironment.current().backend.name == "cute":
         promotions = collect_cute_half_atomic_output_promotions(device_ir.graphs)
         if promotions:
@@ -2626,6 +2628,104 @@ def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
             graph.erase_node(node)


+def fold_gather_into_load(graph: torch.fx.Graph) -> None:
+    """Rewrite ``gather(load(t, [..., :, ...]), dim, idx)`` to a direct
+    indirect ``load(t, [..., idx, ...])`` that picks the gathered elements
+    in one shot.
+
+    The two forms compute the same values, but the direct form skips the
+    full-axis load that the gather output indexes into. That matters in
+    two ways:
+
+    * the original load may be too wide to fit Triton's per-tile element
+      cap (the cross_entropy ``logits[tile_n, :]`` case), and
+    * if that axis is the reduction axis, the original load produces an
+      rdim-shaped value that the reduction roller can't carry out of its
+      ``_for_loop`` to feed the gather (the source of the cross_entropy
+      ``NameError: <load> is not defined`` codegen failure).
+
+    Folding the gather away removes both problems before the rolling
+    analysis runs. The CuTe backend already applies the same fold at
+    codegen time (see ``aten_lowering.codegen_gather_cute``); this pass
+    lifts it to the FX layer so the Triton backend and the roller see the
+    rewritten graph.
+
+    The original load is preserved for any non-gather users (e.g. a
+    sibling reduction over the same axis) — they keep their wide view of
+    the tensor while the gather gets its narrow direct load.
+
+    The fold only fires when:
+
+    * the load's ``dim`` axis is a full ``slice(None)`` (so the gather is
+      genuinely picking one element from a wide axis — for narrow
+      already-indexed axes the fold would produce a different result
+      because Helion's indirect-load pairs tensor indexers elementwise
+      rather than taking a Cartesian product),
+    * the gather index has a singleton at ``dim`` and the same rank as
+      the load's subscript (the cross_entropy
+      ``idx[tile_n].unsqueeze(1)`` pattern). Other gather index shapes
+      — e.g. one that picks ``K`` elements per row — broadcast
+      differently in Helion's indirect-load codegen, so we leave those
+      for the existing ``aten.gather`` path, and
+    * the original load has no ``extra_mask`` (a mask sized to the wide
+      subscript would no longer match the post-fold narrow shape).
+    """
+    for gather in graph.find_nodes(
+        op="call_function", target=torch.ops.aten.gather.default
+    ):
+        if gather.kwargs or len(gather.args) != 3:
+            continue
+        load_node, dim, index_node = gather.args
+        if not (
+            isinstance(load_node, torch.fx.Node)
+            and isinstance(index_node, torch.fx.Node)
+            and isinstance(dim, int)
+            and load_node.target is hl.load
+        ):
+            continue
+        tensor_node, subscript, *load_tail = load_node.args
+        if not isinstance(subscript, (list, tuple)):
+            continue
+        ndim = len(subscript)
+        if dim < 0:
+            dim += ndim
+        if not (0 <= dim < ndim):
+            continue
+        # Original load's gather axis must be a full slice — otherwise the
+        # fold doesn't preserve semantics (see docstring).
+        if not (isinstance(subscript[dim], slice) and subscript[dim] == slice(None)):
+            continue
+        # Forwarding a non-None ``extra_mask`` sized for the wide load
+        # would mismatch the narrow post-fold shape. ``eviction_policy``
+        # is just a string and is fine to forward.
+        extra_mask = load_tail[0] if load_tail else None
+        if extra_mask is not None:
+            continue
+        index_val = index_node.meta.get("val")
+        if not (
+            isinstance(index_val, torch.Tensor)
+            and index_val.ndim == ndim
+            and CompileEnvironment.current().size_hint(index_val.size(dim)) == 1
+        ):
+            continue
+        new_subscript = [
+            (index_node if i == dim else s) for i, s in enumerate(subscript)
+        ]
+        with graph.inserting_before(gather):
+            new_load = graph.call_function(
+                hl.load,
+                (tensor_node, new_subscript, *load_tail),
+                {},
+            )
+        # The new load's value matches the gather's shape/dtype, so reuse
+        # gather.meta (val, lowering, etc.) verbatim.
+        new_load.meta.update(gather.meta)
+        gather.replace_all_uses_with(new_load)
+        graph.erase_node(gather)
+        if not load_node.users:
+            graph.erase_node(load_node)
+
+
 def collect_cute_half_atomic_output_promotions(
     graph_infos: list[GraphInfo],
 ) -> dict[str, torch.dtype]:
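The gating conditions in `fold_gather_into_load` can be read as a single predicate over the load's subscript and the gather index shape. The sketch below restates them in plain Python under simplifying assumptions: `can_fold` is a hypothetical name, the subscript entries are placeholder strings rather than FX nodes, and the index shape is a tuple of concrete ints where the real pass consults `size_hint` on a possibly symbolic size.

```python
def can_fold(subscript, dim, index_shape, extra_mask=None):
    """Return the normalized dim if the fold conditions hold, else None.

    Hypothetical helper mirroring the checks in the pass; not Helion API.
    """
    if not isinstance(subscript, (list, tuple)):
        return None
    ndim = len(subscript)
    if dim < 0:
        dim += ndim  # normalize negative dims, as the pass does
    if not (0 <= dim < ndim):
        return None
    entry = subscript[dim]
    # The gathered axis must be a full slice in the original load.
    if not (isinstance(entry, slice) and entry == slice(None)):
        return None
    # A mask sized for the wide load would not fit the narrow result.
    if extra_mask is not None:
        return None
    # Index must have the subscript's rank and a singleton at dim.
    if len(index_shape) != ndim or index_shape[dim] != 1:
        return None
    return dim


cross_entropy_like = can_fold(["tile_n", slice(None)], 1, (4, 1))
negative_dim = can_fold(["tile_n", slice(None)], -1, (4, 1))
k_per_row = can_fold(["tile_n", slice(None)], 1, (4, 3))
already_indexed = can_fold(["tile_n", "tile_v"], 1, (4, 1))
masked = can_fold(["tile_n", slice(None)], 1, (4, 1), extra_mask="mask")
```

Only the first two calls qualify; the remaining three return `None` and would stay on the existing `aten.gather` path.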

test/test_indexing.py

Lines changed: 35 additions & 0 deletions
@@ -2467,6 +2467,41 @@ def test_gather(

         torch.testing.assert_close(result, expected)

+    @skipIfTileIR("TileIR does not support gather operation")
+    def test_gather_with_rdim_reduction(self):
+        """torch.gather over the reduction dim of an implicitly rolled load.
+
+        Mirrors the cross_entropy pattern: ``logits[tile_n, :]`` feeds both
+        ``torch.gather(..., 1, idx)`` and ``torch.amax(..., dim=-1)``. With
+        a rolled reduction, the gather (a non-reduction op consuming the
+        rdim) would live outside the loop and only see the last iteration's
+        chunk — Triton then rejected the generated code with
+        ``NameError: <load> is not defined``. V is large enough that the
+        heuristic's default config picks a rolled reduction; with the
+        roller's pre-pass refusing to roll this kernel, the default falls
+        back to a persistent reduction that compiles cleanly.
+        """
+
+        @helion.kernel()
+        def gather_then_reduce(
+            x: torch.Tensor,  # [N, V]
+            idx: torch.Tensor,  # [N]
+        ) -> torch.Tensor:  # [N]
+            n, _v = x.shape
+            out = torch.empty([n], dtype=x.dtype, device=x.device)
+            for tile_n in hl.tile(n):
+                row = x[tile_n, :]
+                gathered = row.gather(1, idx[tile_n].unsqueeze(1)).squeeze(1)
+                out[tile_n] = torch.amax(row, dim=-1) - gathered
+            return out
+
+        n, v = 4, 8192
+        x = torch.randn(n, v, device=DEVICE, dtype=torch.float32)
+        idx = torch.randint(0, v, (n,), device=DEVICE, dtype=torch.int64)
+        _, result = code_and_output(gather_then_reduce, (x, idx))
+        expected = torch.amax(x, dim=-1) - x.gather(1, idx[:, None]).squeeze(1)
+        torch.testing.assert_close(result, expected)
+
     @skipIfTileIR("TileIR does not support gather operation")
     def test_gather_2d_dim0(self):
         @helion.kernel()
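The interaction the new test guards against can be modeled without a GPU. In the plain-Python sketch below, a rolled reduction visits a row one block at a time, so the only slice of the row that exists at any moment is the current chunk. The function name `rolled_amax` and the data are hypothetical, and the model only approximates how a rolled loop behaves; it is not generated Helion or Triton code.

```python
def rolled_amax(row, block):
    """Max over `row`, visiting it `block` elements at a time."""
    acc = float("-inf")
    chunk = []
    for start in range(0, len(row), block):
        chunk = row[start:start + block]  # only this slice is live per iteration
        acc = max(acc, max(chunk))
    return acc, chunk  # after the loop, chunk is just the final block


row = [3.0, 9.0, 1.0, 4.0, 7.0, 2.0, 8.0, 5.0]
target = 2  # column the gather wants for this row

row_max, last_chunk = rolled_amax(row, block=4)

# A gather placed after the loop has no full row to index into; the last
# chunk holds a different element at the same position.
stale = last_chunk[target]

# The folded form reads the element straight from the source row instead.
direct = row[target]
result = row_max - direct
```

Because the direct read does not depend on any loop-carried row, the reduction is free to stay rolled while the gathered value is still correct.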
