Skip to content

Commit eb00094

Browse files
committed
[Pallas] Indirect gather via one-hot matmul codegen
1 parent 6414444 commit eb00094

4 files changed

Lines changed: 195 additions & 3 deletions

File tree

helion/_compiler/pallas/plan_tiling.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,17 @@ class NonePattern(IndexingPattern):
6464
"""None index pattern (broadcasting dimension) - allow tiling."""
6565

6666

67+
@dataclass
68+
class IndirectGatherPattern(IndexingPattern):
69+
"""Pattern for table[idx_tensor, :] where idx_tensor is a runtime tensor.
70+
71+
Codegen emits one_hot(idx, V) @ table. The table's first dim gets a None
72+
BlockSpec (entire table in VMEM, no tiling on that dim).
73+
"""
74+
75+
idx_block_id: int | None = None
76+
77+
6778
@dataclass
6879
class DimensionTiling:
6980
"""Tiling decision for a specific dimension of a tensor
@@ -183,6 +194,11 @@ def _detect_indexing_pattern(
183194

184195
if isinstance(idx, torch.fx.Node):
185196
idx_val = idx.meta.get("val")
197+
if isinstance(idx_val, torch.Tensor):
198+
idx_block_id: int | None = None
199+
if idx_val.ndim >= 1:
200+
idx_block_id = env.get_block_id(idx_val.shape[0])
201+
return IndirectGatherPattern(idx_block_id=idx_block_id)
186202
if isinstance(idx_val, torch.SymInt):
187203
block_id = env.get_block_id(idx_val)
188204
if block_id is not None:
@@ -270,6 +286,9 @@ def _try_set_tiling_block_id(new_block_id: int) -> None:
270286
elif isinstance(pattern, NonePattern):
271287
pass
272288

289+
elif isinstance(pattern, IndirectGatherPattern):
290+
_disallow_tiling()
291+
273292
if isinstance(pattern, (TilePattern, TileBeginWithOffsetPattern)):
274293
block_size = env.block_sizes[pattern.block_id].from_config(config)
275294
if isinstance(block_size, int):

helion/language/memory_ops.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,102 @@ def _pallas_generated_index_code(
304304
)
305305

306306

307+
# Conservative VMEM threshold for gather tables. Emits a clear error
308+
# instead of a generic Mosaic OOM. Should be replaced with context-aware
309+
# VMEM budget accounting (e.g. querying actual capacity and other allocations).
310+
_PALLAS_GATHER_VMEM_THRESHOLD_BYTES = 16 << 20 # 16 MiB
311+
312+
313+
def _pallas_indirect_gather_positions(
314+
indexing_patterns: list[object],
315+
) -> list[int]:
316+
from .._compiler.pallas.plan_tiling import IndirectGatherPattern
317+
318+
return [
319+
i
320+
for i, p in enumerate(indexing_patterns)
321+
if isinstance(p, IndirectGatherPattern)
322+
]
323+
324+
325+
def _pallas_emit_gather_load(
326+
state: CodegenState,
327+
tensor: torch.Tensor,
328+
subscript: list[object] | tuple[object, ...],
329+
indexing_patterns: list[object],
330+
indirect_positions: list[int],
331+
name: str,
332+
) -> ast.AST:
333+
"""Emit a one-hot matmul gather: one_hot(idx, V) @ table."""
334+
from .._compiler.pallas.plan_tiling import IndirectGatherPattern
335+
336+
if len(indirect_positions) > 1:
337+
raise NotImplementedError(
338+
"Pallas backend: gather with multiple indirect dims is not supported"
339+
)
340+
indirect_pos = indirect_positions[0]
341+
if indirect_pos != 0:
342+
raise NotImplementedError(
343+
"Pallas backend: indirect gather is only supported on dim 0"
344+
)
345+
pattern = indexing_patterns[indirect_pos]
346+
assert isinstance(pattern, IndirectGatherPattern)
347+
348+
table_bytes = tensor.numel() * tensor.dtype.itemsize
349+
if (
350+
isinstance(table_bytes, int)
351+
and table_bytes > _PALLAS_GATHER_VMEM_THRESHOLD_BYTES
352+
):
353+
raise NotImplementedError(
354+
f"Pallas backend: indirect gather requires the full table in VMEM "
355+
f"({table_bytes} bytes > {_PALLAS_GATHER_VMEM_THRESHOLD_BYTES} byte "
356+
f"threshold). Tile the kernel so the gathered table fits, or use a "
357+
f"different access pattern."
358+
)
359+
360+
if not tensor.dtype.is_floating_point:
361+
raise NotImplementedError(
362+
f"Pallas backend: indirect gather requires a floating-point table, "
363+
f"got {tensor.dtype}"
364+
)
365+
366+
vocab_size = tensor.shape[0]
367+
368+
ast_subscripts = state.ast_args[1]
369+
assert isinstance(ast_subscripts, list)
370+
ast_idx = ast_subscripts[indirect_pos]
371+
assert isinstance(ast_idx, ast.AST)
372+
idx_name = state.codegen.lift(ast_idx, dce=False, prefix="index").id
373+
374+
# Collect none_dims from subscript for expand_dims after the matmul
375+
none_dims: list[int] = []
376+
for out_pos, idx in enumerate(subscript):
377+
if idx is None:
378+
none_dims.append(out_pos)
379+
380+
jnp_dtype = CompileEnvironment.current().backend.dtype_str(tensor.dtype)
381+
# TPU MXU requires 32-bit accumulator. For float32 tables we also need
382+
# Precision.HIGHEST to prevent MXU from truncating inputs to bfloat16
383+
# before multiply-accumulate. For half types the truncation is a no-op.
384+
needs_highest = tensor.dtype not in (torch.bfloat16, torch.float16)
385+
precision_arg = "precision=jax.lax.Precision.HIGHEST, " if needs_highest else ""
386+
result = expr_from_string(
387+
f"jax.lax.dot_general("
388+
f"jax.nn.one_hot({idx_name}[...], {vocab_size}, dtype=jnp.float32), "
389+
f"{name}[...].astype(jnp.float32), "
390+
f"(((1,), (0,)), ((), ())), "
391+
f"preferred_element_type=jnp.float32, "
392+
f"{precision_arg}"
393+
f").astype({jnp_dtype})"
394+
)
395+
396+
for dim in none_dims:
397+
result = expr_from_string(
398+
f"jnp.expand_dims({{result}}, axis={dim})", result=result
399+
)
400+
return result
401+
402+
307403
def _pallas_tile_pattern_code(
308404
pattern: object,
309405
idx: object,
@@ -448,6 +544,12 @@ def _(state: CodegenState) -> None:
448544
device_fn = state.device_function
449545
device_fn.device_store_index += 1
450546
device_fn.device_memory_op_index += 1
547+
indexing_patterns = _pallas_get_indexing_patterns(state, tensor)
548+
if _pallas_indirect_gather_positions(indexing_patterns):
549+
# TODO(pallas-scatter): emit one_hot(idx, V).T @ values
550+
raise NotImplementedError(
551+
"Pallas backend: indirect store (scatter) is not supported"
552+
)
451553
index_str, _ = _pallas_index_str(state, subscript, tensor)
452554
state.codegen.add_statement(
453555
statement_from_string(f"{name}[{index_str}] = {{value}}", value=value)
@@ -1507,6 +1609,14 @@ def _(state: CodegenState) -> ast.AST:
15071609
device_fn = state.device_function
15081610
device_fn.device_load_index += 1
15091611
device_fn.device_memory_op_index += 1
1612+
1613+
indexing_patterns = _pallas_get_indexing_patterns(state, tensor)
1614+
indirect_positions = _pallas_indirect_gather_positions(indexing_patterns)
1615+
if indirect_positions:
1616+
return _pallas_emit_gather_load(
1617+
state, tensor, subscript, indexing_patterns, indirect_positions, name
1618+
)
1619+
15101620
index_str, none_dims = _pallas_index_str(state, subscript, tensor)
15111621
result = expr_from_string(f"{name}[{index_str}]")
15121622
for dim in none_dims:

test/test_examples.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,6 @@ def test_softmax_two_pass_block_ptr(self):
472472
indexing="block_ptr",
473473
)
474474

475-
@xfailIfPallas("missing BlockSpec for hl.load with computed indices")
476475
def test_cross_entropy(self):
477476
n, v = 128, 1000
478477
logits = torch.randn(n, v, device=DEVICE, dtype=torch.float32)
@@ -673,7 +672,6 @@ def test_rms_norm_bwd(self):
673672
atol=1e-2,
674673
)
675674

676-
@xfailIfPallas("BlockSpec tiling failure")
677675
def test_embedding_pointers(self):
678676
args = (
679677
torch.randint(0, 1024, [8, 128], device=DEVICE, dtype=torch.int32),
@@ -687,7 +685,6 @@ def test_embedding_pointers(self):
687685
indexing="pointer",
688686
)
689687

690-
@xfailIfPallas("BlockSpec tiling failure")
691688
@patch.object(_compat, "_supports_tensor_descriptor", lambda: False)
692689
@skipIfTileIR("TileIR does not support block_ptr indexing")
693690
def test_embedding_block_ptr(self):

test/test_pallas.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,72 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
904904
expected = (x[:, None] < y[None, :]).to(torch.float32)
905905
torch.testing.assert_close(result, expected)
906906

907+
@staticmethod
908+
def _indirect_gather_kernel():
909+
@helion.kernel(backend="pallas", static_shapes=True)
910+
def gather(indices: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
911+
out = torch.empty(
912+
[indices.size(0), table.size(1)],
913+
dtype=table.dtype,
914+
device=table.device,
915+
)
916+
for tile_b, tile_e in hl.tile([indices.size(0), table.size(1)]):
917+
out[tile_b, tile_e] = table[indices[tile_b], tile_e]
918+
return out
919+
920+
return gather
921+
922+
def test_indirect_gather_fits_vmem(self) -> None:
923+
"""Indirect gather emits one_hot matmul."""
924+
gather = self._indirect_gather_kernel()
925+
table = torch.randn(16, 64, device=DEVICE, dtype=torch.float32)
926+
indices = torch.randint(0, 16, (256,), device=DEVICE, dtype=torch.int32)
927+
code, result = code_and_output(gather, (indices, table), block_sizes=[128, 64])
928+
self.assertIn("one_hot", code)
929+
self.assertIn("HIGHEST", code)
930+
expected = table.cpu()[indices.long().cpu()]
931+
torch.testing.assert_close(result.cpu(), expected)
932+
933+
def test_indirect_gather_bf16(self) -> None:
934+
"""Indirect gather with bf16 table skips HIGHEST precision."""
935+
gather = self._indirect_gather_kernel()
936+
table = torch.randn(16, 64, device=DEVICE, dtype=torch.bfloat16)
937+
indices = torch.randint(0, 16, (256,), device=DEVICE, dtype=torch.int32)
938+
code, result = code_and_output(gather, (indices, table), block_sizes=[128, 64])
939+
self.assertIn("one_hot", code)
940+
self.assertIn("astype(jnp.bfloat16)", code)
941+
self.assertNotIn("HIGHEST", code)
942+
expected = table.cpu()[indices.long().cpu()]
943+
torch.testing.assert_close(result.cpu(), expected)
944+
945+
def test_indirect_gather_too_large_raises(self) -> None:
946+
"""Indirect gather table over the VMEM threshold raises NotImplementedError."""
947+
gather = self._indirect_gather_kernel()
948+
# 65537 * 64 * 4 bytes = 16 MiB + 256 bytes, just above the threshold.
949+
table = torch.randn(65537, 64, device=DEVICE, dtype=torch.float32)
950+
indices = torch.randint(0, 65537, (256,), device=DEVICE, dtype=torch.int32)
951+
with self.assertRaisesRegex(
952+
Exception, "indirect gather requires the full table"
953+
):
954+
code_and_output(gather, (indices, table), block_sizes=[128, 64])
955+
956+
def test_indirect_store_scatter_raises(self) -> None:
957+
"""Scatter (indirect store) is rejected with a clear error."""
958+
959+
@helion.kernel(backend="pallas", static_shapes=True)
960+
def scatter(
961+
out: torch.Tensor, values: torch.Tensor, indices: torch.Tensor
962+
) -> torch.Tensor:
963+
for tile_b, tile_e in hl.tile([values.size(0), values.size(1)]):
964+
out[indices[tile_b], tile_e] = values[tile_b, tile_e]
965+
return out
966+
967+
out = torch.zeros(16, 64, device=DEVICE, dtype=torch.float32)
968+
values = torch.randn(8, 64, device=DEVICE, dtype=torch.float32)
969+
indices = torch.arange(8, device=DEVICE, dtype=torch.int32)
970+
with self.assertRaisesRegex(Exception, "indirect store"):
971+
code_and_output(scatter, (out, values, indices), block_sizes=[8, 64])
972+
907973

908974
if __name__ == "__main__":
909975
unittest.main()

0 commit comments

Comments
 (0)