Skip to content

Commit a0e7b35

Browse files
committed
[Pallas] Switch gather to jnp.take_along_axis (for JAX issue filing)
This version uses jnp.take_along_axis which is the natural JAX equivalent of torch.gather. It works in interpret mode but fails on real TPU due to a limitation in Mosaic's lax.gather lowering rule which requires indices.shape == input.shape + (1,). Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com> stack-info: PR: #2061, branch: AmesingFlank/stack/27
1 parent b253dba commit a0e7b35

1 file changed

Lines changed: 6 additions & 29 deletions

File tree

helion/_compiler/aten_lowering.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1767,12 +1767,11 @@ def codegen_gather(ctx: LoweringContext, node: Node) -> object:
17671767

17681768
@gather_lowering.register_codegen("pallas")
17691769
def codegen_gather_pallas(ctx: LoweringContext, node: Node) -> object:
1770-
"""Generate gather for Pallas using one_hot + multiply + sum.
1770+
"""Generate gather for Pallas using jnp.take_along_axis.
17711771
1772-
TPU Mosaic has limited lax.gather support, so we implement
1773-
gather(input, dim, index) as:
1774-
mask = one_hot(index.squeeze(dim), input.shape[dim], dtype=input.dtype)
1775-
result = sum(input * mask, axis=dim, keepdims=True)
1772+
This is the natural JAX equivalent of torch.gather but currently fails
1773+
on TPU Mosaic due to limited lax.gather lowering support.
1774+
See: https://github.com/jax-ml/jax/issues/XXXXX
17761775
"""
17771776
assert not node.kwargs, "gather does not support keyword arguments"
17781777
assert len(node.args) == 3, f"gather expects 3 arguments, got {len(node.args)}"
@@ -1805,35 +1804,13 @@ def codegen_gather_pallas(ctx: LoweringContext, node: Node) -> object:
18051804
index_ast = _env_arg(ctx, index_node)
18061805
assert isinstance(index_ast, ast.AST)
18071806

1808-
idx_var = fn.new_var("gather_idx")
1809-
mask_var = fn.new_var("gather_mask")
18101807
result_var = fn.new_var("gather_result")
18111808

18121809
ctx.cg.add_statement(
18131810
statement_from_string(
1814-
f"{idx_var} = jnp.squeeze({{index}}.astype(jnp.int32), axis={dim})",
1815-
index=index_ast,
1816-
)
1817-
)
1818-
1819-
ctx.cg.add_statement(
1820-
statement_from_string(
1821-
f"{mask_var} = jax.nn.one_hot({idx_var}, {{input}}.shape[{dim}], dtype={{input}}.dtype)",
1822-
input=input_ast,
1823-
)
1824-
)
1825-
1826-
if dim != ndim - 1:
1827-
ctx.cg.add_statement(
1828-
statement_from_string(
1829-
f"{mask_var} = jnp.moveaxis({mask_var}, -1, {dim})",
1830-
)
1831-
)
1832-
1833-
ctx.cg.add_statement(
1834-
statement_from_string(
1835-
f"{result_var} = jnp.sum({{input}} * {mask_var}, axis={dim}, keepdims=True)",
1811+
f"{result_var} = jnp.take_along_axis({{input}}, {{index}}.astype(jnp.int32), axis={dim})",
18361812
input=input_ast,
1813+
index=index_ast,
18371814
)
18381815
)
18391816

0 commit comments

Comments
 (0)