Skip to content

Commit 25dcb32

Browse files
committed
[Pallas] Switch gather to jnp.take_along_axis (for JAX issue filing)
This version uses jnp.take_along_axis, which is the natural JAX equivalent of torch.gather. It works in interpret mode but fails on real TPU due to a limitation in Mosaic's lax.gather lowering rule, which requires indices.shape == input.shape + (1,).
Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
stack-info: PR: #2061, branch: AmesingFlank/stack/27
1 parent 55f875b commit 25dcb32

1 file changed

Lines changed: 6 additions & 29 deletions

File tree

helion/_compiler/aten_lowering.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2513,12 +2513,11 @@ def codegen_gather(ctx: LoweringContext, node: Node) -> object:
25132513

25142514
@gather_lowering.register_codegen("pallas")
25152515
def codegen_gather_pallas(ctx: LoweringContext, node: Node) -> object:
2516-
"""Generate gather for Pallas using one_hot + multiply + sum.
2516+
"""Generate gather for Pallas using jnp.take_along_axis.
25172517
2518-
TPU Mosaic has limited lax.gather support, so we implement
2519-
gather(input, dim, index) as:
2520-
mask = one_hot(index.squeeze(dim), input.shape[dim], dtype=input.dtype)
2521-
result = sum(input * mask, axis=dim, keepdims=True)
2518+
This is the natural JAX equivalent of torch.gather but currently fails
2519+
on TPU Mosaic due to limited lax.gather lowering support.
2520+
See: https://github.com/jax-ml/jax/issues/XXXXX
25222521
"""
25232522
assert not node.kwargs, "gather does not support keyword arguments"
25242523
assert len(node.args) == 3, f"gather expects 3 arguments, got {len(node.args)}"
@@ -2551,35 +2550,13 @@ def codegen_gather_pallas(ctx: LoweringContext, node: Node) -> object:
25512550
index_ast = _env_arg(ctx, index_node)
25522551
assert isinstance(index_ast, ast.AST)
25532552

2554-
idx_var = fn.new_var("gather_idx")
2555-
mask_var = fn.new_var("gather_mask")
25562553
result_var = fn.new_var("gather_result")
25572554

25582555
ctx.cg.add_statement(
25592556
statement_from_string(
2560-
f"{idx_var} = jnp.squeeze({{index}}.astype(jnp.int32), axis={dim})",
2561-
index=index_ast,
2562-
)
2563-
)
2564-
2565-
ctx.cg.add_statement(
2566-
statement_from_string(
2567-
f"{mask_var} = jax.nn.one_hot({idx_var}, {{input}}.shape[{dim}], dtype={{input}}.dtype)",
2568-
input=input_ast,
2569-
)
2570-
)
2571-
2572-
if dim != ndim - 1:
2573-
ctx.cg.add_statement(
2574-
statement_from_string(
2575-
f"{mask_var} = jnp.moveaxis({mask_var}, -1, {dim})",
2576-
)
2577-
)
2578-
2579-
ctx.cg.add_statement(
2580-
statement_from_string(
2581-
f"{result_var} = jnp.sum({{input}} * {mask_var}, axis={dim}, keepdims=True)",
2557+
f"{result_var} = jnp.take_along_axis({{input}}, {{index}}.astype(jnp.int32), axis={dim})",
25822558
input=input_ast,
2559+
index=index_ast,
25832560
)
25842561
)
25852562

0 commit comments

Comments
 (0)