Skip to content

Commit c40398c

Browse files
authored
[JAX] Size autotuned Triton grids per config (#2975)
* [JAX] Size autotuned Triton grids per config (3x perm-kernel speedup)

The autotuned path in triton_call_lowering compiled all BLOCK_SIZE configs but dispatched every one with the same fixed grid sized for the smallest BLOCK_SIZE, so larger configs over-launched by the BLOCK_SIZE ratio. Make grid accept a callable(meta)->tuple evaluated per config, matching the jax-triton API. Update _permute_kernel, _unpermute_kernel, and _sort_chunks_by_map_kernel lowerings.

Measured 22.6ms -> 7.4ms (3.06x) on GB200 for sort_chunks at 524k tokens, hidden=4096, fp32.

* [JAX] Triton wrapper defaults match jax-triton (3.25ms speedup)

num_warps default 32->4 and num_stages 1->3 in triton_call_lowering match Triton's own triton.Config defaults. Non-autotuned kernels (e.g. _make_chunk_sort_map_kernel) were running with 1024 threads/block, an 8x kernel slowdown.

Also: tuple/callable grid assertion + comment trims.

Signed-off-by: tdophung <tdophung@nvidia.com>
1 parent 4322c0a commit c40398c

2 files changed

Lines changed: 84 additions & 34 deletions

File tree

transformer_engine/jax/triton_extensions/permutation.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -589,10 +589,13 @@ def lowering(
589589
probs_stride_token = 0
590590
probs_stride_expert = 0
591591

592-
# Grid function equivalent: (num_tokens, cdiv(hidden_size, BLOCK_SIZE))
593-
# Use minimum BLOCK_SIZE from autotune configs to ensure grid covers all elements
592+
# We use BLOCK_SIZE in the grid calculation to ensure the grid is the
593+
# proper size. If the grid size is an overestimate it can significantly
594+
# hurt performance.
595+
def grid(meta):
596+
return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))
597+
594598
block_size = _get_min_block_size(_permute_kernel)
595-
grid = (num_tokens, triton.cdiv(hidden_size, block_size))
596599

597600
# Use input_output_aliases to alias pre-zeroed buffers to outputs.
598601
# This ensures padding positions contain zeros since the kernel only writes valid positions.
@@ -997,9 +1000,13 @@ def lowering(
9971000
unpermuted_probs_stride_token = num_experts
9981001
unpermuted_probs_stride_expert = 1
9991002

1000-
# Grid - use minimum BLOCK_SIZE from autotune configs
1003+
# We use BLOCK_SIZE in the grid calculation to ensure the grid is the
1004+
# proper size. If the grid size is an overestimate it can significantly
1005+
# hurt performance.
1006+
def grid(meta):
1007+
return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))
1008+
10011009
block_size = _get_min_block_size(_unpermute_kernel)
1002-
grid = (num_tokens, triton.cdiv(hidden_size, block_size))
10031010

10041011
return triton_call_lowering(
10051012
ctx,
@@ -1720,9 +1727,13 @@ def lowering(
17201727
probs_stride_token = 1
17211728
permuted_probs_stride_token = 1
17221729

1723-
# Grid - use minimum BLOCK_SIZE from autotune configs
1730+
# We use BLOCK_SIZE in the grid calculation to ensure the grid is the
1731+
# proper size. If the grid size is an overestimate it can significantly
1732+
# hurt performance.
1733+
def grid(meta):
1734+
return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))
1735+
17241736
block_size = _get_min_block_size(_sort_chunks_by_map_kernel)
1725-
grid = (num_tokens, triton.cdiv(hidden_size, block_size))
17261737

17271738
# Declare input_output_aliases so XLA knows output slot 0 is claimed by
17281739
# input 3 (output_buf). This prevents XLA from implicitly aliasing any

transformer_engine/jax/triton_extensions/utils.py

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,16 @@ def triton_call_lowering(
390390
ctx: MLIR lowering context
391391
kernel_fn: Triton kernel function
392392
*array_args: Input arrays (from ctx)
393-
grid: Grid dimensions (int or tuple)
393+
grid: Grid dimensions. May be either:
394+
- a tuple (fixed grid for every config; a bare int is rejected by the type check below), or
395+
- a callable ``meta -> int|tuple`` (evaluated per autotune config).
396+
397+
Use the callable form for autotuned kernels whose grid depends on
398+
``BLOCK_SIZE`` (or any other autotuned constexpr); otherwise the
399+
launch grid will not match the autotuner-selected config and the
400+
kernel will either over-launch (waste) or under-cover. ``meta`` is
401+
the merged dict ``{**constexprs, **config.kwargs}`` for the chosen
402+
config — the same convention as jax-triton's ``triton_call``.
394403
input_output_aliases: Mapping of input to output aliases
395404
constexprs: Compile-time constants for the kernel. This includes both
396405
tl.constexpr arguments AND scalar runtime arguments (like
@@ -404,13 +413,12 @@ def triton_call_lowering(
404413
def lowering(ctx, x, *, block_size):
405414
from ..triton_extensions import triton_call_lowering
406415
n = ctx.avals_in[0].size
416+
417+
def grid(meta):
418+
return (triton.cdiv(n, meta["BLOCK_SIZE"]),)
419+
407420
return triton_call_lowering(
408-
ctx, my_kernel, x,
409-
grid=(triton.cdiv(n, block_size),),
410-
constexprs={
411-
"n_elements": n, # scalar arg (not tl.constexpr in kernel)
412-
"BLOCK_SIZE": block_size, # tl.constexpr arg
413-
},
421+
ctx, my_kernel, x, grid=grid, constexprs={"n_elements": n},
414422
)
415423
"""
416424
# Get compute capability using gpu_triton
@@ -431,22 +439,39 @@ def lowering(ctx, x, *, block_size):
431439
tensor_arg_names = [n for n in arg_names if n not in constexpr_names]
432440
signature = {n: get_triton_dtype(a) for n, a in zip(tensor_arg_names, all_avals)}
433441

434-
# Normalize grid to 3D
435-
if isinstance(grid, int):
436-
grid_tuple = (grid, 1, 1)
437-
elif len(grid) == 1:
438-
grid_tuple = (grid[0], 1, 1)
439-
elif len(grid) == 2:
440-
grid_tuple = (grid[0], grid[1], 1)
441-
else:
442-
grid_tuple = grid[:3]
442+
assert callable(grid) or isinstance(grid, tuple), (
443+
"Argument 'grid' must be a tuple or a callable but received: "
444+
f"type={type(grid)}, value={grid}"
445+
)
443446

444-
# Default values for the kernel
447+
# Normalize grid to 3D. When `grid` is a callable, defer evaluation until
448+
# we know the per-config meta (so each autotune config gets its own grid,
449+
# matching jax-triton's behavior).
450+
def _normalize_grid(grid_tuple):
451+
if isinstance(grid_tuple, int):
452+
return (grid_tuple, 1, 1)
453+
if len(grid_tuple) == 1:
454+
return (grid_tuple[0], 1, 1)
455+
if len(grid_tuple) == 2:
456+
return (grid_tuple[0], grid_tuple[1], 1)
457+
return tuple(grid_tuple[:3])
458+
459+
grid_callable = grid if callable(grid) else None
460+
if grid_callable is None:
461+
grid_tuple = _normalize_grid(grid)
462+
else:
463+
grid_tuple = None # evaluated per-config below
464+
465+
# Default kernel launch parameters. These apply to non-autotuned kernels
466+
# and as a fallback when an autotuned config doesn't specify them. Values
467+
# match Triton's own `triton.Config` defaults (num_warps=4, num_stages=3,
468+
# num_ctas=1) and jax-triton's `get_or_create_triton_kernel`. Using a
469+
# larger default (e.g. num_warps=32) over-provisions threads per block,
470+
# which slashes SM occupancy on non-autotuned kernels — measured as an 8×
471+
# slowdown on `_make_chunk_sort_map_kernel` vs jax-triton.
445472
actual_kernel_fn = kernel_fn
446-
num_warps = 32
447-
num_stages = (
448-
1 # TODO(Phuong): consider if it is beneficial to expose num_warps, num_stages, num_ctas
449-
)
473+
num_warps = 4
474+
num_stages = 3
450475
num_ctas = 1
451476
kernel_constexprs = constexprs if constexprs is not None else {}
452477

@@ -510,11 +535,18 @@ def lowering(ctx, x, *, block_size):
510535
for _ in list(ctx.avals_in) + list(ctx.avals_out):
511536
config_params.append(gpu_triton.create_array_parameter(0, 16))
512537

538+
# Per-config grid: evaluate `grid(meta)` if grid is a callable so
539+
# the launch shape matches this config's BLOCK_SIZE (etc.).
540+
if grid_callable is not None:
541+
config_grid = _normalize_grid(grid_callable(config_constexprs))
542+
else:
543+
config_grid = grid_tuple
544+
513545
config_call = gpu_triton.TritonKernelCall(
514546
config_kernel,
515-
grid_tuple[0],
516-
grid_tuple[1],
517-
grid_tuple[2],
547+
config_grid[0],
548+
config_grid[1],
549+
config_grid[2],
518550
config_params,
519551
)
520552

@@ -571,11 +603,18 @@ def lowering(ctx, x, *, block_size):
571603
for _ in list(ctx.avals_in) + list(ctx.avals_out):
572604
kernel_params.append(gpu_triton.create_array_parameter(0, 16))
573605

606+
# Non-autotuned dispatch: evaluate `grid(meta)` once with the merged
607+
# constexprs (which already reflect the single config we'll launch).
608+
if grid_callable is not None:
609+
single_grid = _normalize_grid(grid_callable(kernel_constexprs))
610+
else:
611+
single_grid = grid_tuple
612+
574613
kernel_call = gpu_triton.TritonKernelCall(
575614
kernel,
576-
grid_tuple[0],
577-
grid_tuple[1],
578-
grid_tuple[2],
615+
single_grid[0],
616+
single_grid[1],
617+
single_grid[2],
579618
kernel_params,
580619
)
581620

0 commit comments

Comments
 (0)