@@ -390,7 +390,16 @@ def triton_call_lowering(
390390 ctx: MLIR lowering context
391391 kernel_fn: Triton kernel function
392392 *array_args: Input arrays (from ctx)
393- grid: Grid dimensions (int or tuple)
393+ grid: Grid dimensions. May be either:
394+ - an int or tuple (fixed grid for every config), or
395+ - a callable ``meta -> int|tuple`` (evaluated per autotune config).
396+
397+ Use the callable form for autotuned kernels whose grid depends on
398+ ``BLOCK_SIZE`` (or any other autotuned constexpr); otherwise the
399+ launch grid will not match the autotuner-selected config and the
400+ kernel will either over-launch (waste) or under-cover. ``meta`` is
401+ the merged dict ``{**constexprs, **config.kwargs}`` for the chosen
402+ config — the same convention as jax-triton's ``triton_call``.
394403 input_output_aliases: Mapping of input to output aliases
395404 constexprs: Compile-time constants for the kernel. This includes both
396405 tl.constexpr arguments AND scalar runtime arguments (like
@@ -404,13 +413,12 @@ def triton_call_lowering(
404413 def lowering(ctx, x, *, block_size):
405414 from ..triton_extensions import triton_call_lowering
406415 n = ctx.avals_in[0].size
416+
417+ def grid(meta):
418+ return (triton.cdiv(n, meta["BLOCK_SIZE"]),)
419+
407420 return triton_call_lowering(
408- ctx, my_kernel, x,
409- grid=(triton.cdiv(n, block_size),),
410- constexprs={
411- "n_elements": n, # scalar arg (not tl.constexpr in kernel)
412- "BLOCK_SIZE": block_size, # tl.constexpr arg
413- },
421+ ctx, my_kernel, x, grid=grid, constexprs={"n_elements": n},
414422 )
415423 """
416424 # Get compute capability using gpu_triton
@@ -431,22 +439,39 @@ def lowering(ctx, x, *, block_size):
431439 tensor_arg_names = [n for n in arg_names if n not in constexpr_names ]
432440 signature = {n : get_triton_dtype (a ) for n , a in zip (tensor_arg_names , all_avals )}
433441
434- # Normalize grid to 3D
435- if isinstance (grid , int ):
436- grid_tuple = (grid , 1 , 1 )
437- elif len (grid ) == 1 :
438- grid_tuple = (grid [0 ], 1 , 1 )
439- elif len (grid ) == 2 :
440- grid_tuple = (grid [0 ], grid [1 ], 1 )
441- else :
442- grid_tuple = grid [:3 ]
442+ assert callable (grid ) or isinstance (grid , tuple ), (
443+ "Argument 'grid' must be a tuple or a callable but received: "
444+ f"type={ type (grid )} , value={ grid } "
445+ )
443446
444- # Default values for the kernel
447+ # Normalize grid to 3D. When `grid` is a callable, defer evaluation until
448+ # we know the per-config meta (so each autotune config gets its own grid,
449+ # matching jax-triton's behavior).
450+ def _normalize_grid (grid_tuple ):
451+ if isinstance (grid_tuple , int ):
452+ return (grid_tuple , 1 , 1 )
453+ if len (grid_tuple ) == 1 :
454+ return (grid_tuple [0 ], 1 , 1 )
455+ if len (grid_tuple ) == 2 :
456+ return (grid_tuple [0 ], grid_tuple [1 ], 1 )
457+ return tuple (grid_tuple [:3 ])
458+
459+ grid_callable = grid if callable (grid ) else None
460+ if grid_callable is None :
461+ grid_tuple = _normalize_grid (grid )
462+ else :
463+ grid_tuple = None # evaluated per-config below
464+
465+ # Default kernel launch parameters. These apply to non-autotuned kernels
466+ # and as a fallback when an autotuned config doesn't specify them. Values
467+ # match Triton's own `triton.Config` defaults (num_warps=4, num_stages=3,
468+ # num_ctas=1) and jax-triton's `get_or_create_triton_kernel`. Using a
469+ # larger default (e.g. num_warps=32) over-provisions threads per block,
470+ # which slashes SM occupancy on non-autotuned kernels — measured as an 8×
471+ # slowdown on `_make_chunk_sort_map_kernel` vs jax-triton.
445472 actual_kernel_fn = kernel_fn
446- num_warps = 32
447- num_stages = (
448- 1 # TODO(Phuong): consider if it is beneficial to expose num_warps, num_stages, num_ctas
449- )
473+ num_warps = 4
474+ num_stages = 3
450475 num_ctas = 1
451476 kernel_constexprs = constexprs if constexprs is not None else {}
452477
@@ -510,11 +535,18 @@ def lowering(ctx, x, *, block_size):
510535 for _ in list (ctx .avals_in ) + list (ctx .avals_out ):
511536 config_params .append (gpu_triton .create_array_parameter (0 , 16 ))
512537
538+ # Per-config grid: evaluate `grid(meta)` if grid is a callable so
539+ # the launch shape matches this config's BLOCK_SIZE (etc.).
540+ if grid_callable is not None :
541+ config_grid = _normalize_grid (grid_callable (config_constexprs ))
542+ else :
543+ config_grid = grid_tuple
544+
513545 config_call = gpu_triton .TritonKernelCall (
514546 config_kernel ,
515- grid_tuple [0 ],
516- grid_tuple [1 ],
517- grid_tuple [2 ],
547+ config_grid [0 ],
548+ config_grid [1 ],
549+ config_grid [2 ],
518550 config_params ,
519551 )
520552
@@ -571,11 +603,18 @@ def lowering(ctx, x, *, block_size):
571603 for _ in list (ctx .avals_in ) + list (ctx .avals_out ):
572604 kernel_params .append (gpu_triton .create_array_parameter (0 , 16 ))
573605
606+ # Non-autotuned dispatch: evaluate `grid(meta)` once with the merged
607+ # constexprs (which already reflect the single config we'll launch).
608+ if grid_callable is not None :
609+ single_grid = _normalize_grid (grid_callable (kernel_constexprs ))
610+ else :
611+ single_grid = grid_tuple
612+
574613 kernel_call = gpu_triton .TritonKernelCall (
575614 kernel ,
576- grid_tuple [0 ],
577- grid_tuple [1 ],
578- grid_tuple [2 ],
615+ single_grid [0 ],
616+ single_grid [1 ],
617+ single_grid [2 ],
579618 kernel_params ,
580619 )
581620
0 commit comments