Skip to content

Commit 696b52e

Browse files
committed
[Pallas] Lower aten gather using one_hot + sum for TPU compatibility
TPU Mosaic has very limited lax.gather support, so jnp.take_along_axis fails during lowering. Instead, implement gather(input, dim, index) as: `mask = one_hot(index.squeeze(dim), input.shape[dim], dtype=input.dtype)`; `result = sum(input * mask, axis=dim, keepdims=True)`. Also removes the xfailIfPallas mark from test_cross_entropy since the gather lowering now works. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>. stack-info: PR: #2060, branch: AmesingFlank/stack/26
1 parent c95b79f commit 696b52e

2 files changed

Lines changed: 68 additions & 1 deletion

File tree

helion/_compiler/aten_lowering.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1765,6 +1765,74 @@ def codegen_gather(ctx: LoweringContext, node: Node) -> object:
17651765
return expr_from_string(result_var)
17661766

17671767

1768+
@gather_lowering.register_codegen("pallas")
def codegen_gather_pallas(ctx: LoweringContext, node: Node) -> object:
    """Generate gather for Pallas using one_hot + multiply + sum.

    TPU Mosaic has limited lax.gather support, so instead of emitting a
    real gather we emit three statements implementing
    gather(input, dim, index) as:

        idx    = squeeze(index.astype(int32), axis=dim)
        mask   = one_hot(idx, input.shape[dim], dtype=input.dtype, axis=dim)
        result = sum(input * mask, axis=dim, keepdims=True)

    The one-hot class axis is placed at ``dim`` (not at the default last
    position) so that ``mask`` lines up with ``input`` for any ``dim``.

    Limitations of this lowering:
      * The emitted ``jnp.squeeze(..., axis=dim)`` requires the index to
        have size 1 along ``dim``; other index shapes are not supported.
      * Selection is done by multiplication, so a non-finite value
        (inf/NaN) anywhere along ``dim`` turns the result into NaN even
        when that element is not the selected one.

    Args:
        ctx: Lowering context; ``ctx.cg`` supplies fresh variable names
            and receives the generated statements.
        node: The gather call node. Must carry exactly three positional
            args ``(input, dim, index)`` and no kwargs.

    Returns:
        An expression (built with ``expr_from_string``) that refers to
        the variable holding the gathered result.

    Raises:
        AssertionError: If the node's arguments do not have the expected
            arity/types or ``dim`` is out of range for the input.
    """
    assert not node.kwargs, "gather does not support keyword arguments"
    assert len(node.args) == 3, f"gather expects 3 arguments, got {len(node.args)}"

    input_node = node.args[0]
    dim = node.args[1]
    index_node = node.args[2]

    assert isinstance(input_node, Node), "gather input must be a Node"
    assert isinstance(dim, int), f"gather dim must be int, got {type(dim)}"
    assert isinstance(index_node, Node), "gather index must be a Node"

    input_tensor = input_node.meta["val"]
    assert isinstance(input_tensor, torch.Tensor), (
        f"gather input must be a tensor, got {type(input_tensor)}"
    )

    # Normalise a negative dim so it can be embedded verbatim in the
    # generated code as a non-negative axis.
    ndim = input_tensor.ndim
    if dim < 0:
        dim = ndim + dim
    assert 0 <= dim < ndim, (
        f"gather dim {dim} out of range for tensor with {ndim} dimensions"
    )

    fn = ctx.cg.device_function

    input_ast = _env_arg(ctx, input_node)
    assert isinstance(input_ast, ast.AST)

    index_ast = _env_arg(ctx, index_node)
    assert isinstance(index_ast, ast.AST)

    idx_var = fn.new_var("gather_idx")
    mask_var = fn.new_var("gather_mask")
    result_var = fn.new_var("gather_result")

    # Drop the size-1 gather axis so one_hot can re-insert it as the
    # class axis below.
    ctx.cg.add_statement(
        statement_from_string(
            f"{idx_var} = jnp.squeeze({{index}}.astype(jnp.int32), axis={dim})",
            index=index_ast,
        )
    )

    # axis={dim} is essential: jax.nn.one_hot appends the class axis at
    # the end by default, which only matches the input layout when dim is
    # the last axis.
    ctx.cg.add_statement(
        statement_from_string(
            f"{mask_var} = jax.nn.one_hot("
            f"{idx_var}, {{input}}.shape[{dim}], "
            f"dtype={{input}}.dtype, axis={dim})",
            input=input_ast,
        )
    )

    # Zero out everything except the selected element and reduce;
    # keepdims=True restores the size-1 axis at dim.
    ctx.cg.add_statement(
        statement_from_string(
            f"{result_var} = jnp.sum("
            f"{{input}} * {mask_var}, axis={dim}, keepdims=True)",
            input=input_ast,
        )
    )

    return expr_from_string(result_var)
1834+
1835+
17681836
# Lowering handle for aten.topk.default, obtained from register_lowering.
topk_lowering = register_lowering(torch.ops.aten.topk.default)
17691837

17701838

test/test_examples.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,6 @@ def test_softmax_two_pass_block_ptr(self):
472472
indexing="block_ptr",
473473
)
474474

475-
@xfailIfPallas("missing BlockSpec for hl.load with computed indices")
476475
def test_cross_entropy(self):
477476
n, v = 128, 1000
478477
logits = torch.randn(n, v, device=DEVICE, dtype=torch.float32)

0 commit comments

Comments
 (0)