Commit 13ba6b2
Fold gather(load(t, [..., :, ...]), dim, idx) into direct indirect load
The cross_entropy pattern (logits[tile_n, :].gather(1, idx[tile_n].unsqueeze(1)))
was producing invalid Triton (NameError on the load) when the reduction roller
tried to roll the surrounding amax/sum: a _for_loop output can't carry the
rdim-shaped logits_rows out to feed the gather sitting outside the loop.
Rewrite gather(load(t, [..., :, ...]), dim, idx) at the FX layer to a direct
indirect load(t, [..., idx, ...]). The two forms compute the same values, but
the direct form skips the wide load entirely, so the rdim-shaped intermediate
never exists and the roller's existing logic handles the surrounding reductions
without special casing. The CuTe backend already does this fold at codegen time
(aten_lowering.codegen_gather_cute); lifting it to FX surfaces the same
simplification to the Triton backend and the rolling analysis.
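The equivalence the fold relies on can be checked with a small dependency-free model. This is only an illustrative sketch: nested lists stand in for tensors, and the helper names (gather_dim1, wide_load_then_gather, direct_indirect_load) are hypothetical, not part of Helion or PyTorch.

```python
# Pure-Python model of the two forms; nested lists stand in for tensors.


def gather_dim1(rows, index):
    """Model of gather along dim 1 for 2-D inputs: out[i][j] = rows[i][index[i][j]]."""
    return [[row[j] for j in idx_row] for row, idx_row in zip(rows, index)]


def wide_load_then_gather(table, tile, idx):
    # Form 1: load the full rows (the wide intermediate), then gather.
    rows = [table[n] for n in tile]    # models load(t, [tile, :])
    index = [[idx[n]] for n in tile]   # models idx[tile].unsqueeze(1)
    return gather_dim1(rows, index)


def direct_indirect_load(table, tile, idx):
    # Form 2: read one element per row; no full-row intermediate is built.
    return [[table[n][idx[n]]] for n in tile]  # models load(t, [tile, idx])


logits = [[0.1, 0.2, 0.7], [0.5, 0.3, 0.2], [0.0, 0.9, 0.1]]
labels = [2, 0, 1]
tile_n = [0, 1, 2]

unfolded = wide_load_then_gather(logits, tile_n, labels)
folded = direct_indirect_load(logits, tile_n, labels)
assert folded == unfolded
```

Both forms return one value per row with the singleton dimension kept, which is why the rewrite can replace the gather without touching its users.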
The fold is gated to the cross_entropy-style pattern: the load's subscript at
dim is a full slice, the gather index has size 1 at dim and the same rank as
the load's subscript, and the load has no extra_mask. Other gather shapes
still go through the existing aten.gather path.
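The gate can be pictured as a small predicate. The sketch below is a hypothetical stand-in, not the code added by this commit: the function name, its arguments, and the use of plain Python values for the subscript and index shape are all assumptions made for illustration.

```python
from typing import Optional, Sequence


def _is_full_slice(entry: object) -> bool:
    # A full slice is the ":" in a subscript such as t[tile_n, :].
    return isinstance(entry, slice) and entry == slice(None)


def can_fold_gather_of_load(
    subscript: Sequence[object],
    dim: int,
    index_shape: Sequence[int],
    extra_mask: Optional[object] = None,
) -> bool:
    """Return True when gather(load(t, subscript), dim, idx) matches the gated pattern."""
    rank = len(subscript)
    if extra_mask is not None:            # loads with an extra mask are left alone
        return False
    if not -rank <= dim < rank:           # dim must address an axis of the load
        return False
    dim %= rank                           # normalize negative dims
    if not _is_full_slice(subscript[dim]):
        return False                      # the gathered axis must be loaded whole
    if len(index_shape) != rank:
        return False                      # index rank must match the subscript
    return index_shape[dim] == 1          # exactly one element picked along dim


# The cross_entropy pattern: logits[tile_n, :].gather(1, idx[tile_n].unsqueeze(1))
assert can_fold_gather_of_load(["tile_n", slice(None)], 1, (8, 1))
```

Anything the predicate rejects, such as an index with more than one element along dim, would stay on the ordinary gather path.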
After this, examples/cross_entropy.py runs end-to-end: autotuning finds rolled
configs (block_sizes=[1], reduction_loops=[16384]) and the kernel is ~3x
faster than torch eager.
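To see why the fold lets the reductions roll, here is a scalar model of one row, assuming the usual logsumexp formulation of cross entropy. It is a sketch, not the contents of examples/cross_entropy.py; the function names are hypothetical and chunk is a stand-in for the size of each rolled step.

```python
import math


def cross_entropy_row_rolled(row, label, chunk):
    # Pass 1: rolled amax. Only the running maximum leaves the loop.
    row_max = -math.inf
    for start in range(0, len(row), chunk):
        row_max = max(row_max, max(row[start:start + chunk]))

    # Pass 2: rolled sum of exp(x - max). Only the running sum leaves the loop.
    total = 0.0
    for start in range(0, len(row), chunk):
        total += sum(math.exp(x - row_max) for x in row[start:start + chunk])

    # Direct indexed load of the target logit. It does not need the full row,
    # so nothing row-sized has to be carried out of either loop.
    target = row[label]
    return row_max + math.log(total) - target


def cross_entropy_row_reference(row, label):
    # Unrolled reference: logsumexp(row) - row[label].
    return math.log(sum(math.exp(x) for x in row)) - row[label]


row = [0.5, -1.0, 2.0, 0.25, 1.5]
rolled = cross_entropy_row_rolled(row, label=2, chunk=2)
reference = cross_entropy_row_reference(row, label=2)
assert abs(rolled - reference) < 1e-9
```

Each loop hands only a scalar to the code after it, which is the property the unfolded gather broke by requiring the whole row outside the loop.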
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
stack-info: PR: #2684, branch: AmesingFlank/stack/631 parent 8c7d65d commit 13ba6b2
2 files changed
Lines changed: 143 additions & 0 deletions
File 1 of 2 (file name and diff content not preserved):
  new lines 2395-2396 added (+2)
  new lines 2631-2734 added (+104)
File 2 of 2 (file name and diff content not preserved):
  new line 24 added (+1)
  new lines 2471-2506 added (+36)