Commit 696b52e
committed
[Pallas] Lower aten gather using one_hot + sum for TPU compatibility
TPU Mosaic has very limited lax.gather support, so jnp.take_along_axis
fails during lowering. Instead, implement gather(input, dim, index) as:
mask = one_hot(index.squeeze(dim), input.shape[dim], dtype=input.dtype)
result = sum(input * mask, axis=dim, keepdims=True)
Also removes the xfailIfPallas mark from test_cross_entropy since the
gather lowering now works.
Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
stack-info: PR: #2060, branch: AmesingFlank/stack/261 parent c95b79f commit 696b52e
2 files changed
Lines changed: 68 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1765 | 1765 | | |
1766 | 1766 | | |
1767 | 1767 | | |
| 1768 | + | |
| 1769 | + | |
| 1770 | + | |
| 1771 | + | |
| 1772 | + | |
| 1773 | + | |
| 1774 | + | |
| 1775 | + | |
| 1776 | + | |
| 1777 | + | |
| 1778 | + | |
| 1779 | + | |
| 1780 | + | |
| 1781 | + | |
| 1782 | + | |
| 1783 | + | |
| 1784 | + | |
| 1785 | + | |
| 1786 | + | |
| 1787 | + | |
| 1788 | + | |
| 1789 | + | |
| 1790 | + | |
| 1791 | + | |
| 1792 | + | |
| 1793 | + | |
| 1794 | + | |
| 1795 | + | |
| 1796 | + | |
| 1797 | + | |
| 1798 | + | |
| 1799 | + | |
| 1800 | + | |
| 1801 | + | |
| 1802 | + | |
| 1803 | + | |
| 1804 | + | |
| 1805 | + | |
| 1806 | + | |
| 1807 | + | |
| 1808 | + | |
| 1809 | + | |
| 1810 | + | |
| 1811 | + | |
| 1812 | + | |
| 1813 | + | |
| 1814 | + | |
| 1815 | + | |
| 1816 | + | |
| 1817 | + | |
| 1818 | + | |
| 1819 | + | |
| 1820 | + | |
| 1821 | + | |
| 1822 | + | |
| 1823 | + | |
| 1824 | + | |
| 1825 | + | |
| 1826 | + | |
| 1827 | + | |
| 1828 | + | |
| 1829 | + | |
| 1830 | + | |
| 1831 | + | |
| 1832 | + | |
| 1833 | + | |
| 1834 | + | |
| 1835 | + | |
1768 | 1836 | | |
1769 | 1837 | | |
1770 | 1838 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
472 | 472 | | |
473 | 473 | | |
474 | 474 | | |
475 | | - | |
476 | 475 | | |
477 | 476 | | |
478 | 477 | | |
| |||
0 commit comments