From 9ef6bb5a17174586cf6ca12a63f6ad3bc83e61c1 Mon Sep 17 00:00:00 2001 From: Dunfan Lu Date: Mon, 20 Apr 2026 21:34:17 +0000 Subject: [PATCH] [Pallas] Switch gather to jnp.take_along_axis (for JAX issue filing) This version uses jnp.take_along_axis which is the natural JAX equivalent of torch.gather. It works in interpret mode but fails on real TPU due to a limitation in Mosaic's lax.gather lowering rule which requires indices.shape == input.shape + (1,). Co-Authored-By: Claude Sonnet 4 stack-info: PR: https://github.com/pytorch/helion/pull/2061, branch: AmesingFlank/stack/27 --- helion/_compiler/aten_lowering.py | 35 ++++++------------------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/helion/_compiler/aten_lowering.py b/helion/_compiler/aten_lowering.py index 3b25ed9e7..f4a85b867 100644 --- a/helion/_compiler/aten_lowering.py +++ b/helion/_compiler/aten_lowering.py @@ -2513,12 +2513,11 @@ def codegen_gather(ctx: LoweringContext, node: Node) -> object: @gather_lowering.register_codegen("pallas") def codegen_gather_pallas(ctx: LoweringContext, node: Node) -> object: - """Generate gather for Pallas using one_hot + multiply + sum. + """Generate gather for Pallas using jnp.take_along_axis. - TPU Mosaic has limited lax.gather support, so we implement - gather(input, dim, index) as: - mask = one_hot(index.squeeze(dim), input.shape[dim], dtype=input.dtype) - result = sum(input * mask, axis=dim, keepdims=True) + This is the natural JAX equivalent of torch.gather but currently fails + on TPU Mosaic due to limited lax.gather lowering support. + See: https://github.com/jax-ml/jax/issues/XXXXX """ assert not node.kwargs, "gather does not support keyword arguments" assert len(node.args) == 3, f"gather expects 3 arguments, got {len(node.args)}" @@ -2551,35 +2550,13 @@ def codegen_gather_pallas(ctx: LoweringContext, node: Node) -> object: index_ast = _env_arg(ctx, index_node) assert isinstance(index_ast, ast.AST) - idx_var = fn.new_var("gather_idx") - mask_var = fn.new_var("gather_mask") result_var = fn.new_var("gather_result") ctx.cg.add_statement( statement_from_string( - f"{idx_var} = jnp.squeeze({{index}}.astype(jnp.int32), axis={dim})", - index=index_ast, - ) - ) - - ctx.cg.add_statement( - statement_from_string( - f"{mask_var} = jax.nn.one_hot({idx_var}, {{input}}.shape[{dim}], dtype={{input}}.dtype)", - input=input_ast, - ) - ) - - if dim != ndim - 1: - ctx.cg.add_statement( - statement_from_string( - f"{mask_var} = jnp.moveaxis({mask_var}, -1, {dim})", - ) - ) - - ctx.cg.add_statement( - statement_from_string( - f"{result_var} = jnp.sum({{input}} * {mask_var}, axis={dim}, keepdims=True)", + f"{result_var} = jnp.take_along_axis({{input}}, {{index}}.astype(jnp.int32), axis={dim})", input=input_ast, + index=index_ast, ) )