Skip to content

Commit 55f875b

Browse files
committed
[Pallas] Lower aten gather using one_hot + sum for TPU compatibility, unblocking cross_entropy
TPU Mosaic has very limited lax.gather support, so jnp.take_along_axis fails during lowering. Instead, implement gather(input, dim, index) as:
    mask = one_hot(index.squeeze(dim), input.shape[dim], dtype=input.dtype)
    result = sum(input * mask, axis=dim, keepdims=True)
Because the index is squeezed along dim, this lowering applies when index has size 1 along dim. This commit also removes the xfailIfPallas mark from test_cross_entropy, since the gather lowering now works.
Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
stack-info: PR: #2060, branch: AmesingFlank/stack/26
1 parent 1ad88fd commit 55f875b

2 files changed

Lines changed: 170 additions & 0 deletions

File tree

helion/_compiler/aten_lowering.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2511,6 +2511,81 @@ def codegen_gather(ctx: LoweringContext, node: Node) -> object:
25112511
return expr_from_string(result_var)
25122512

25132513

2514+
@gather_lowering.register_codegen("pallas")
def codegen_gather_pallas(ctx: LoweringContext, node: Node) -> object:
    """Generate ``gather`` for Pallas as a one-hot select followed by a sum.

    TPU Mosaic has limited ``lax.gather`` support, so
    ``gather(input, dim, index)`` is emitted as::

        idx    = squeeze(index.astype(int32), axis=dim)
        mask   = one_hot(idx, input.shape[dim], dtype=input.dtype)
        mask   = moveaxis(mask, -1, dim)      # only when dim is not the last axis
        result = sum(where(mask != 0, input, 0), axis=dim, keepdims=True)

    Unselected entries are discarded with ``where`` instead of being
    multiplied by zero: ``inf * 0`` and ``nan * 0`` are ``nan``, so a plain
    ``input * mask`` would poison the result whenever any *other* element
    along ``dim`` is non-finite (e.g. ``-inf`` masked logits).

    Limitations of the emitted code:
      * ``jnp.squeeze(..., axis=dim)`` requires the index to have size 1
        along ``dim``.
      * The mask is combined elementwise with the input, so the index is
        expected to match the input on the remaining axes.
        NOTE(review): neither restriction is validated here -- confirm
        whether callers can reach this lowering with other index shapes.

    Args:
        ctx: Lowering context; statements are appended through
            ``ctx.cg.add_statement`` and temporaries are allocated from
            ``ctx.cg.device_function``.
        node: FX node whose positional args are ``(input, dim, index)``.

    Returns:
        An expression referring to the variable holding the gathered result.

    Raises:
        AssertionError: If keyword arguments are present, the argument count
            or argument types are unexpected, or ``dim`` is out of range.
    """
    assert not node.kwargs, "gather does not support keyword arguments"
    assert len(node.args) == 3, f"gather expects 3 arguments, got {len(node.args)}"

    input_node, dim, index_node = node.args

    assert isinstance(input_node, Node), "gather input must be a Node"
    assert isinstance(dim, int), f"gather dim must be int, got {type(dim)}"
    assert isinstance(index_node, Node), "gather index must be a Node"

    input_tensor = input_node.meta["val"]
    assert isinstance(input_tensor, torch.Tensor), (
        f"gather input must be a tensor, got {type(input_tensor)}"
    )

    # Normalize a negative dim so it can be baked into the generated source.
    ndim = input_tensor.ndim
    if dim < 0:
        dim = ndim + dim
    assert 0 <= dim < ndim, (
        f"gather dim {dim} out of range for tensor with {ndim} dimensions"
    )

    fn = ctx.cg.device_function

    input_ast = _env_arg(ctx, input_node)
    assert isinstance(input_ast, ast.AST)

    index_ast = _env_arg(ctx, index_node)
    assert isinstance(index_ast, ast.AST)

    idx_var = fn.new_var("gather_idx")
    mask_var = fn.new_var("gather_mask")
    result_var = fn.new_var("gather_result")

    # Drop the size-1 gather axis so one_hot can append the class axis.
    ctx.cg.add_statement(
        statement_from_string(
            f"{idx_var} = jnp.squeeze({{index}}.astype(jnp.int32), axis={dim})",
            index=index_ast,
        )
    )

    # one_hot places the class axis last.
    ctx.cg.add_statement(
        statement_from_string(
            f"{mask_var} = jax.nn.one_hot({idx_var}, {{input}}.shape[{dim}], dtype={{input}}.dtype)",
            input=input_ast,
        )
    )

    # Move the class axis back to `dim` so the mask lines up with the input.
    if dim != ndim - 1:
        ctx.cg.add_statement(
            statement_from_string(
                f"{mask_var} = jnp.moveaxis({mask_var}, -1, {dim})",
            )
        )

    # Select (rather than multiply) so non-finite values in unselected
    # positions cannot turn the reduction into NaN.
    ctx.cg.add_statement(
        statement_from_string(
            f"{result_var} = jnp.sum(jnp.where({mask_var} != 0, {{input}}, 0), axis={dim}, keepdims=True)",
            input=input_ast,
        )
    )

    return expr_from_string(result_var)
2587+
2588+
25142589
@gather_lowering.register_codegen("cute")
25152590
def codegen_gather_cute(ctx: LoweringContext, node: Node) -> object:
25162591
assert not node.kwargs, "gather does not support keyword arguments"

test/test_pallas.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3466,6 +3466,101 @@ def k(x: torch.Tensor) -> torch.Tensor:
34663466
inner_min = spec.block_sizes[1].min_size
34673467
self.assertGreaterEqual(outer_min, inner_min)
34683468

3469+
def test_gather_2d_dim_1(self) -> None:
    """Gather along dim=1 of a 2D tensor and compare with eager ``gather``."""

    @helion.kernel(
        backend="pallas",
        static_shapes=True,
        ignore_warnings=[helion.exc.TensorOperationInWrapper],
    )
    def fn(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        # Tile over axis 0; every tile gathers across the whole of axis 1.
        num_rows, _num_cols = x.shape
        gathered = torch.zeros([num_rows, 1], dtype=x.dtype, device=x.device)
        for row_tile in hl.tile(num_rows):
            gathered[row_tile, :] = x[row_tile, :].gather(1, idx[row_tile, :])
        return gathered

    source = torch.randn(64, 256, device=DEVICE, dtype=torch.float32)
    indices = torch.randint(0, 256, (64, 1), device=DEVICE, dtype=torch.int32)
    # Eager reference, with the indices converted via .long().
    expected = source.gather(1, indices.long())
    _code, actual = code_and_output(fn, (source, indices), block_size=64)
    torch.testing.assert_close(actual, expected)
3487+
3488+
def test_gather_2d_dim_0(self) -> None:
    """Gather along dim=0 of a 2D tensor and compare with eager ``gather``."""

    @helion.kernel(
        backend="pallas",
        static_shapes=True,
        ignore_warnings=[helion.exc.TensorOperationInWrapper],
    )
    def fn(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        # Tile over axis 1; every tile gathers across the whole of axis 0.
        _num_rows, num_cols = x.shape
        gathered = torch.zeros([1, num_cols], dtype=x.dtype, device=x.device)
        for col_tile in hl.tile(num_cols):
            gathered[:, col_tile] = x[:, col_tile].gather(0, idx[:, col_tile])
        return gathered

    source = torch.randn(128, 64, device=DEVICE, dtype=torch.float32)
    indices = torch.randint(0, 128, (1, 64), device=DEVICE, dtype=torch.int32)
    # Eager reference, with the indices converted via .long().
    expected = source.gather(0, indices.long())
    _code, actual = code_and_output(fn, (source, indices), block_size=64)
    torch.testing.assert_close(actual, expected)
3506+
3507+
def test_gather_3d_dim_0(self) -> None:
    """Gather along dim=0 of a 3D tensor and compare with eager ``gather``."""

    @helion.kernel(
        backend="pallas",
        static_shapes=True,
        ignore_warnings=[helion.exc.TensorOperationInWrapper],
    )
    def fn(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        # Tile over axis 1; axes 0 and 2 are taken whole in each tile.
        _size0, size1, size2 = x.shape
        gathered = torch.zeros([1, size1, size2], dtype=x.dtype, device=x.device)
        for mid_tile in hl.tile(size1):
            gathered[:, mid_tile, :] = x[:, mid_tile, :].gather(
                0, idx[:, mid_tile, :]
            )
        return gathered

    source = torch.randn(32, 16, 8, device=DEVICE, dtype=torch.float32)
    indices = torch.randint(0, 32, (1, 16, 8), device=DEVICE, dtype=torch.int32)
    # Eager reference, with the indices converted via .long().
    expected = source.gather(0, indices.long())
    _code, actual = code_and_output(fn, (source, indices), block_size=16)
    torch.testing.assert_close(actual, expected)
3525+
3526+
def test_gather_3d_dim_1(self) -> None:
    """Gather along dim=1 of a 3D tensor and compare with eager ``gather``."""

    @helion.kernel(
        backend="pallas",
        static_shapes=True,
        ignore_warnings=[helion.exc.TensorOperationInWrapper],
    )
    def fn(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        # Tile over axis 0; axes 1 and 2 are taken whole in each tile.
        size0, _size1, size2 = x.shape
        gathered = torch.zeros([size0, 1, size2], dtype=x.dtype, device=x.device)
        for lead_tile in hl.tile(size0):
            gathered[lead_tile, :, :] = x[lead_tile, :, :].gather(
                1, idx[lead_tile, :, :]
            )
        return gathered

    source = torch.randn(16, 32, 8, device=DEVICE, dtype=torch.float32)
    indices = torch.randint(0, 32, (16, 1, 8), device=DEVICE, dtype=torch.int32)
    # Eager reference, with the indices converted via .long().
    expected = source.gather(1, indices.long())
    _code, actual = code_and_output(fn, (source, indices), block_size=16)
    torch.testing.assert_close(actual, expected)
3544+
3545+
def test_gather_3d_dim_2(self) -> None:
    """Gather along dim=2 of a 3D tensor and compare with eager ``gather``."""

    @helion.kernel(
        backend="pallas",
        static_shapes=True,
        ignore_warnings=[helion.exc.TensorOperationInWrapper],
    )
    def fn(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        # Tile over axis 0; axes 1 and 2 are taken whole in each tile.
        size0, size1, _size2 = x.shape
        gathered = torch.zeros([size0, size1, 1], dtype=x.dtype, device=x.device)
        for lead_tile in hl.tile(size0):
            gathered[lead_tile, :, :] = x[lead_tile, :, :].gather(
                2, idx[lead_tile, :, :]
            )
        return gathered

    source = torch.randn(16, 8, 64, device=DEVICE, dtype=torch.float32)
    indices = torch.randint(0, 64, (16, 8, 1), device=DEVICE, dtype=torch.int32)
    # Eager reference, with the indices converted via .long().
    expected = source.gather(2, indices.long())
    _code, actual = code_and_output(fn, (source, indices), block_size=16)
    torch.testing.assert_close(actual, expected)
3563+
34693564

34703565
@skipUnlessPallas("JAX/Pallas TPU not available")
34713566
class TestPallasIndirectGather(TestCase):

0 commit comments

Comments
 (0)