@@ -304,6 +304,102 @@ def _pallas_generated_index_code(
304304 )
305305
306306
307+ # Conservative VMEM threshold for gather tables. Emits a clear error
308+ # instead of a generic Mosaic OOM. Should be replaced with context-aware
309+ # VMEM budget accounting (e.g. querying actual capacity and other allocations).
310+ _PALLAS_GATHER_VMEM_THRESHOLD_BYTES = 16 << 20 # 16 MiB
311+
312+
313+ def _pallas_indirect_gather_positions (
314+ indexing_patterns : list [object ],
315+ ) -> list [int ]:
316+ from .._compiler .pallas .plan_tiling import IndirectGatherPattern
317+
318+ return [
319+ i
320+ for i , p in enumerate (indexing_patterns )
321+ if isinstance (p , IndirectGatherPattern )
322+ ]
323+
324+
325+ def _pallas_emit_gather_load (
326+ state : CodegenState ,
327+ tensor : torch .Tensor ,
328+ subscript : list [object ] | tuple [object , ...],
329+ indexing_patterns : list [object ],
330+ indirect_positions : list [int ],
331+ name : str ,
332+ ) -> ast .AST :
333+ """Emit a one-hot matmul gather: one_hot(idx, V) @ table."""
334+ from .._compiler .pallas .plan_tiling import IndirectGatherPattern
335+
336+ if len (indirect_positions ) > 1 :
337+ raise NotImplementedError (
338+ "Pallas backend: gather with multiple indirect dims is not supported"
339+ )
340+ indirect_pos = indirect_positions [0 ]
341+ if indirect_pos != 0 :
342+ raise NotImplementedError (
343+ "Pallas backend: indirect gather is only supported on dim 0"
344+ )
345+ pattern = indexing_patterns [indirect_pos ]
346+ assert isinstance (pattern , IndirectGatherPattern )
347+
348+ table_bytes = tensor .numel () * tensor .dtype .itemsize
349+ if (
350+ isinstance (table_bytes , int )
351+ and table_bytes > _PALLAS_GATHER_VMEM_THRESHOLD_BYTES
352+ ):
353+ raise NotImplementedError (
354+ f"Pallas backend: indirect gather requires the full table in VMEM "
355+ f"({ table_bytes } bytes > { _PALLAS_GATHER_VMEM_THRESHOLD_BYTES } byte "
356+ f"threshold). Tile the kernel so the gathered table fits, or use a "
357+ f"different access pattern."
358+ )
359+
360+ if not tensor .dtype .is_floating_point :
361+ raise NotImplementedError (
362+ f"Pallas backend: indirect gather requires a floating-point table, "
363+ f"got { tensor .dtype } "
364+ )
365+
366+ vocab_size = tensor .shape [0 ]
367+
368+ ast_subscripts = state .ast_args [1 ]
369+ assert isinstance (ast_subscripts , list )
370+ ast_idx = ast_subscripts [indirect_pos ]
371+ assert isinstance (ast_idx , ast .AST )
372+ idx_name = state .codegen .lift (ast_idx , dce = False , prefix = "index" ).id
373+
374+ # Collect none_dims from subscript for expand_dims after the matmul
375+ none_dims : list [int ] = []
376+ for out_pos , idx in enumerate (subscript ):
377+ if idx is None :
378+ none_dims .append (out_pos )
379+
380+ jnp_dtype = CompileEnvironment .current ().backend .dtype_str (tensor .dtype )
381+ # TPU MXU requires 32-bit accumulator. For float32 tables we also need
382+ # Precision.HIGHEST to prevent MXU from truncating inputs to bfloat16
383+ # before multiply-accumulate. For half types the truncation is a no-op.
384+ needs_highest = tensor .dtype not in (torch .bfloat16 , torch .float16 )
385+ precision_arg = "precision=jax.lax.Precision.HIGHEST, " if needs_highest else ""
386+ result = expr_from_string (
387+ f"jax.lax.dot_general("
388+ f"jax.nn.one_hot({ idx_name } [...], { vocab_size } , dtype=jnp.float32), "
389+ f"{ name } [...].astype(jnp.float32), "
390+ f"(((1,), (0,)), ((), ())), "
391+ f"preferred_element_type=jnp.float32, "
392+ f"{ precision_arg } "
393+ f").astype({ jnp_dtype } )"
394+ )
395+
396+ for dim in none_dims :
397+ result = expr_from_string (
398+ f"jnp.expand_dims({{result}}, axis={ dim } )" , result = result
399+ )
400+ return result
401+
402+
307403def _pallas_tile_pattern_code (
308404 pattern : object ,
309405 idx : object ,
@@ -448,6 +544,12 @@ def _(state: CodegenState) -> None:
448544 device_fn = state .device_function
449545 device_fn .device_store_index += 1
450546 device_fn .device_memory_op_index += 1
547+ indexing_patterns = _pallas_get_indexing_patterns (state , tensor )
548+ if _pallas_indirect_gather_positions (indexing_patterns ):
549+ # TODO(pallas-scatter): emit one_hot(idx, V).T @ values
550+ raise NotImplementedError (
551+ "Pallas backend: indirect store (scatter) is not supported"
552+ )
451553 index_str , _ = _pallas_index_str (state , subscript , tensor )
452554 state .codegen .add_statement (
453555 statement_from_string (f"{ name } [{ index_str } ] = {{value}}" , value = value )
@@ -1507,6 +1609,14 @@ def _(state: CodegenState) -> ast.AST:
15071609 device_fn = state .device_function
15081610 device_fn .device_load_index += 1
15091611 device_fn .device_memory_op_index += 1
1612+
1613+ indexing_patterns = _pallas_get_indexing_patterns (state , tensor )
1614+ indirect_positions = _pallas_indirect_gather_positions (indexing_patterns )
1615+ if indirect_positions :
1616+ return _pallas_emit_gather_load (
1617+ state , tensor , subscript , indexing_patterns , indirect_positions , name
1618+ )
1619+
15101620 index_str , none_dims = _pallas_index_str (state , subscript , tensor )
15111621 result = expr_from_string (f"{ name } [{ index_str } ]" )
15121622 for dim in none_dims :
0 commit comments