diff --git a/helion/_compiler/aten_lowering.py b/helion/_compiler/aten_lowering.py index 3b25ed9e7..f4a85b867 100644 --- a/helion/_compiler/aten_lowering.py +++ b/helion/_compiler/aten_lowering.py @@ -2513,12 +2513,11 @@ def codegen_gather(ctx: LoweringContext, node: Node) -> object: @gather_lowering.register_codegen("pallas") def codegen_gather_pallas(ctx: LoweringContext, node: Node) -> object: - """Generate gather for Pallas using one_hot + multiply + sum. + """Generate gather for Pallas using jnp.take_along_axis. - TPU Mosaic has limited lax.gather support, so we implement - gather(input, dim, index) as: - mask = one_hot(index.squeeze(dim), input.shape[dim], dtype=input.dtype) - result = sum(input * mask, axis=dim, keepdims=True) + This is the natural JAX equivalent of torch.gather but currently fails + on TPU Mosaic due to limited lax.gather lowering support. + See: https://github.com/jax-ml/jax/issues/XXXXX """ assert not node.kwargs, "gather does not support keyword arguments" assert len(node.args) == 3, f"gather expects 3 arguments, got {len(node.args)}" @@ -2551,35 +2550,13 @@ def codegen_gather_pallas(ctx: LoweringContext, node: Node) -> object: index_ast = _env_arg(ctx, index_node) assert isinstance(index_ast, ast.AST) - idx_var = fn.new_var("gather_idx") - mask_var = fn.new_var("gather_mask") result_var = fn.new_var("gather_result") ctx.cg.add_statement( statement_from_string( - f"{idx_var} = jnp.squeeze({{index}}.astype(jnp.int32), axis={dim})", - index=index_ast, - ) - ) - - ctx.cg.add_statement( - statement_from_string( - f"{mask_var} = jax.nn.one_hot({idx_var}, {{input}}.shape[{dim}], dtype={{input}}.dtype)", - input=input_ast, - ) - ) - - if dim != ndim - 1: - ctx.cg.add_statement( - statement_from_string( - f"{mask_var} = jnp.moveaxis({mask_var}, -1, {dim})", - ) - ) - - ctx.cg.add_statement( - statement_from_string( - f"{result_var} = jnp.sum({{input}} * {mask_var}, axis={dim}, keepdims=True)", + f"{result_var} = jnp.take_along_axis({{input}}, {{index}}.astype(jnp.int32), axis={dim})", input=input_ast, + index=index_ast, ) )