Commit 504dcb0
Jetski
Expose GMM tile sizes in Tokamax cleanly via global heuristics monkey-patching
This CL enables specifying the tile sizes for both the forward and backward passes of Tokamax GMM (ragged_dot) in MaxText.
Key changes:
1. Exposes manual tiling configuration overrides in base.yml (wi_tile and wo_tile flags) to specify tile sizes for Forward (fwd), Backward DLHS, and Backward DRHS passes.
2. Dynamically monkey-patches PallasMosaicTpuRaggedDot._get_heuristics_config globally to intercept and route manual GMM tile configurations dynamically based on active operand shapes and JAX dimension numbers.
3. Retains high-level layer implementations completely standard without custom compiler or VJP wrapping code.
4. Adds a comprehensive unit test suite (TokamaxMonkeyPatchTest) in tests/unit/moe_test.py, insulating configurations from cross-test state, and achieving 100% test coverage.
FIXES: b/5061578561 parent 0747df9 commit 504dcb0
3 files changed
Lines changed: 182 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
48 | 48 | | |
49 | 49 | | |
50 | 50 | | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
51 | 59 | | |
52 | 60 | | |
53 | 61 | | |
| |||
549 | 557 | | |
550 | 558 | | |
551 | 559 | | |
| 560 | + | |
| 561 | + | |
| 562 | + | |
552 | 563 | | |
553 | 564 | | |
554 | 565 | | |
| |||
1137 | 1148 | | |
1138 | 1149 | | |
1139 | 1150 | | |
| 1151 | + | |
1140 | 1152 | | |
1141 | 1153 | | |
1142 | | - | |
| 1154 | + | |
1143 | 1155 | | |
1144 | 1156 | | |
1145 | 1157 | | |
| |||
2541 | 2553 | | |
2542 | 2554 | | |
2543 | 2555 | | |
| 2556 | + | |
| 2557 | + | |
| 2558 | + | |
| 2559 | + | |
| 2560 | + | |
| 2561 | + | |
| 2562 | + | |
| 2563 | + | |
| 2564 | + | |
| 2565 | + | |
| 2566 | + | |
| 2567 | + | |
| 2568 | + | |
| 2569 | + | |
| 2570 | + | |
| 2571 | + | |
| 2572 | + | |
| 2573 | + | |
| 2574 | + | |
| 2575 | + | |
| 2576 | + | |
| 2577 | + | |
| 2578 | + | |
| 2579 | + | |
| 2580 | + | |
| 2581 | + | |
| 2582 | + | |
| 2583 | + | |
| 2584 | + | |
| 2585 | + | |
| 2586 | + | |
| 2587 | + | |
| 2588 | + | |
| 2589 | + | |
| 2590 | + | |
| 2591 | + | |
| 2592 | + | |
| 2593 | + | |
| 2594 | + | |
| 2595 | + | |
| 2596 | + | |
| 2597 | + | |
| 2598 | + | |
| 2599 | + | |
| 2600 | + | |
| 2601 | + | |
| 2602 | + | |
| 2603 | + | |
| 2604 | + | |
| 2605 | + | |
| 2606 | + | |
| 2607 | + | |
| 2608 | + | |
| 2609 | + | |
| 2610 | + | |
| 2611 | + | |
| 2612 | + | |
| 2613 | + | |
| 2614 | + | |
| 2615 | + | |
| 2616 | + | |
| 2617 | + | |
| 2618 | + | |
| 2619 | + | |
| 2620 | + | |
| 2621 | + | |
| 2622 | + | |
| 2623 | + | |
| 2624 | + | |
| 2625 | + | |
| 2626 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
29 | 29 | | |
30 | 30 | | |
31 | 31 | | |
| 32 | + | |
32 | 33 | | |
33 | 34 | | |
34 | 35 | | |
| |||
833 | 834 | | |
834 | 835 | | |
835 | 836 | | |
| 837 | + | |
| 838 | + | |
836 | 839 | | |
837 | 840 | | |
838 | 841 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
32 | 32 | | |
33 | 33 | | |
34 | 34 | | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
35 | 42 | | |
36 | 43 | | |
37 | 44 | | |
| |||
1521 | 1528 | | |
1522 | 1529 | | |
1523 | 1530 | | |
| 1531 | + | |
| 1532 | + | |
| 1533 | + | |
| 1534 | + | |
| 1535 | + | |
| 1536 | + | |
| 1537 | + | |
| 1538 | + | |
| 1539 | + | |
| 1540 | + | |
| 1541 | + | |
| 1542 | + | |
| 1543 | + | |
| 1544 | + | |
| 1545 | + | |
| 1546 | + | |
| 1547 | + | |
| 1548 | + | |
| 1549 | + | |
| 1550 | + | |
| 1551 | + | |
| 1552 | + | |
| 1553 | + | |
| 1554 | + | |
| 1555 | + | |
| 1556 | + | |
| 1557 | + | |
| 1558 | + | |
| 1559 | + | |
| 1560 | + | |
| 1561 | + | |
| 1562 | + | |
| 1563 | + | |
| 1564 | + | |
| 1565 | + | |
| 1566 | + | |
| 1567 | + | |
| 1568 | + | |
| 1569 | + | |
| 1570 | + | |
| 1571 | + | |
| 1572 | + | |
| 1573 | + | |
| 1574 | + | |
| 1575 | + | |
| 1576 | + | |
| 1577 | + | |
| 1578 | + | |
| 1579 | + | |
| 1580 | + | |
| 1581 | + | |
| 1582 | + | |
| 1583 | + | |
| 1584 | + | |
| 1585 | + | |
| 1586 | + | |
| 1587 | + | |
| 1588 | + | |
| 1589 | + | |
| 1590 | + | |
| 1591 | + | |
| 1592 | + | |
| 1593 | + | |
| 1594 | + | |
| 1595 | + | |
| 1596 | + | |
| 1597 | + | |
| 1598 | + | |
| 1599 | + | |
| 1600 | + | |
| 1601 | + | |
| 1602 | + | |
| 1603 | + | |
| 1604 | + | |
| 1605 | + | |
| 1606 | + | |
| 1607 | + | |
| 1608 | + | |
| 1609 | + | |
| 1610 | + | |
| 1611 | + | |
| 1612 | + | |
| 1613 | + | |
| 1614 | + | |
| 1615 | + | |
| 1616 | + | |
| 1617 | + | |
| 1618 | + | |
1524 | 1619 | | |
1525 | 1620 | | |
0 commit comments