Commit ac80716

docs(upstream): file PR #2143 (Metal builtins direct) + close Candidate #1 as false alarm
Two parallel agent investigations of follow-up Path C upstream candidates:

CANDIDATE #1: closed as false alarm
docs/upstream/tilelang_metal_inline_kernel_body/

The 'Apple MSL forbids threadgroup allocs in non-kernel functions' bug DOES NOT exist in jorgecurious/tilelang:metal-gemm-upstream-rebase. Metal codegen already emits prim_func bodies directly inside `kernel void` without any inline-void wrapper. xcrun --sdk macosx metal -c compiles T.alloc_shared and T.alloc_fragment+T.gemm prim_funcs cleanly without cppmega post-processing. Verified .air artifacts checked in.

The cppmega `_inline_tilelang_kernel_body` workaround is real but solves a different problem: it adapts TileLang's complete `kernel void` MSL for MLX's mx.fast.metal_kernel API (MLX generates its own kernel signature and doesn't accept a pre-baked one). This is an MLX/TileLang integration concern, not a TileLang codegen bug. No upstream PR needed.

CANDIDATE #2: PR FILED at tile-ai/tilelang#2143
docs/upstream/tilelang_metal_emit_metal_builtins/

The 'CUDA-style threadIdx/blockIdx aliases' bug IS real. TileLang's metal codegen emitted `uint3 blockIdx [[threadgroup_position_in_grid]],` and `((int)threadIdx.x)` references in body. Fix: emit the Metal builtin names directly as kernel-launch parameters and body references.

Patches:
- 0001-metal-emit-builtins-directly.patch (75 lines, +36/-7 in src/target/codegen_metal.cc)
- 0002-tvm-metal-emit-builtins-directly.patch (69 lines, +30/-7 in 3rdparty/tvm/src/target/source/codegen_metal.cc; for the TileLang/tvm submodule, separate companion PR not yet filed)

Verified: TileLang ninja -j8 builds clean. Smoke test on T.copy prim_func shows alias gone, Metal builtin used directly. cppmega Path C tests pass (test_tilelang_mamba3_path_c.py, 11/11); the regex helpers in _msl_transform.py become no-ops with this fix, output identical.

PR: tile-ai/tilelang#2143
1 parent efb6529 commit ac80716

6 files changed

Lines changed: 544 additions & 0 deletions

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
diff --git a/src/target/codegen_metal.cc b/src/target/codegen_metal.cc
index faa43a35..e924ef99 100644
--- a/src/target/codegen_metal.cc
+++ b/src/target/codegen_metal.cc
@@ -153,8 +153,15 @@ void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar,
decl_stream << "};\n\n";
}
// Setup the thread group info.
+ // Reserve the CUDA-style alias names so user code or downstream passes
+ // cannot accidentally collide with them, even though the kernel itself
+ // emits Metal builtin names directly (no `blockIdx`/`threadIdx` aliases).
ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
+ ICHECK_EQ(name_supply_->FreshName("threadgroup_position_in_grid"),
+ "threadgroup_position_in_grid");
+ ICHECK_EQ(name_supply_->FreshName("thread_position_in_threadgroup"),
+ "thread_position_in_threadgroup");
int work_dim = 0;
auto launch_params =
func->GetAttr<ffi::Array<ffi::String>>(tir::attr::kKernelLaunchParams)
@@ -167,13 +174,22 @@ void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar,
}

if (work_dim != 0) {
- // use ushort by default for now
+ // Emit Metal builtin names directly as the kernel parameter identifiers
+ // rather than using CUDA-style `blockIdx`/`threadIdx` aliases. This means
+ // body references are emitted as e.g. `threadgroup_position_in_grid.x`
+ // instead of `blockIdx.x`, which:
+ // - matches Apple's MSL convention,
+ // - eliminates a redundant naming layer that downstream MSL passes had
+ // to undo with regex-based string substitution (see cppmega.mlx
+ // `_msl_transform.py::_canonicalize_tilelang_builtin_aliases`),
+ // - removes intermediate dead-alias declarations that callers had to
+ // strip manually.
stream << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
- stream << " blockIdx [[threadgroup_position_in_grid]],\n";
+ stream << " threadgroup_position_in_grid [[threadgroup_position_in_grid]],\n";
stream << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
- stream << " threadIdx [[thread_position_in_threadgroup]]\n";
+ stream << " thread_position_in_threadgroup [[thread_position_in_threadgroup]]\n";
}
thread_work_dim_ = work_dim;

@@ -188,11 +204,24 @@ void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar,

void CodeGenTileLangMetal::BindThreadIndex(const IterVar &iv) {
ICHECK(!var_idmap_.count(iv->var.get()));
- // if we only have threadIdx.x
- // metal will directly print as threadIdx
+ // The thread_tag is the CUDA-style name (e.g. "threadIdx.x", "blockIdx.y").
+ // Translate to the Metal builtin reference so emitted body references
+ // resolve directly against the kernel parameters declared in AddFunction
+ // (which now use the Metal builtin names verbatim instead of the
+ // blockIdx/threadIdx aliases). The .x/.y/.z suffix is preserved.
std::string vname = iv->thread_tag;
- if (thread_work_dim_ <= 1) {
- vname = vname.substr(0, iv->thread_tag.length() - 2);
+ std::string axis;
+ if (vname.length() >= 2 && vname[vname.length() - 2] == '.') {
+ axis = vname.substr(vname.length() - 2); // ".x" / ".y" / ".z"
+ vname = vname.substr(0, vname.length() - 2);
+ }
+ if (vname == "threadIdx") {
+ vname = "thread_position_in_threadgroup";
+ } else if (vname == "blockIdx") {
+ vname = "threadgroup_position_in_grid";
+ }
+ if (thread_work_dim_ > 1) {
+ vname += axis;
}
var_idmap_[iv->var.get()] =
CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype());
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index 0104277..2645c8e 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -146,8 +146,15 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
decl_stream << "};\n\n";
}
// Setup the thread group info.
+ // Reserve the CUDA-style alias names so user code or downstream passes
+ // cannot accidentally collide with them, even though the kernel itself
+ // emits Metal builtin names directly (no `blockIdx`/`threadIdx` aliases).
ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
+ ICHECK_EQ(name_supply_->FreshName("threadgroup_position_in_grid"),
+ "threadgroup_position_in_grid");
+ ICHECK_EQ(name_supply_->FreshName("thread_position_in_threadgroup"),
+ "thread_position_in_threadgroup");
int work_dim = 0;
auto launch_params =
func->GetAttr<ffi::Array<ffi::String>>(tir::attr::kKernelLaunchParams).value();
@@ -159,13 +166,16 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
}

if (work_dim != 0) {
- // use ushort by default for now
+ // Emit Metal builtin names directly as the kernel parameter identifiers
+ // rather than using CUDA-style `blockIdx`/`threadIdx` aliases. This keeps
+ // body references aligned with Apple's MSL convention and avoids forcing
+ // downstream passes to canonicalize the alias back to the Metal builtin.
stream << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
- stream << " blockIdx [[threadgroup_position_in_grid]],\n";
+ stream << " threadgroup_position_in_grid [[threadgroup_position_in_grid]],\n";
stream << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
- stream << " threadIdx [[thread_position_in_threadgroup]]\n";
+ stream << " thread_position_in_threadgroup [[thread_position_in_threadgroup]]\n";
}
thread_work_dim_ = work_dim;

@@ -180,11 +190,24 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {

void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
ICHECK(!var_idmap_.count(iv->var.get()));
- // if we only have threadIdx.x
- // metal will directly print as threadIdx
+ // The thread_tag is the CUDA-style name (e.g. "threadIdx.x", "blockIdx.y").
+ // Translate to the Metal builtin reference so emitted body references
+ // resolve directly against the kernel parameters declared in AddFunction
+ // (which now use the Metal builtin names verbatim instead of the
+ // blockIdx/threadIdx aliases). The .x/.y/.z suffix is preserved.
std::string vname = iv->thread_tag;
- if (thread_work_dim_ <= 1) {
- vname = vname.substr(0, iv->thread_tag.length() - 2);
+ std::string axis;
+ if (vname.length() >= 2 && vname[vname.length() - 2] == '.') {
+ axis = vname.substr(vname.length() - 2); // ".x" / ".y" / ".z"
+ vname = vname.substr(0, vname.length() - 2);
+ }
+ if (vname == "threadIdx") {
+ vname = "thread_position_in_threadgroup";
+ } else if (vname == "blockIdx") {
+ vname = "threadgroup_position_in_grid";
+ }
+ if (thread_work_dim_ > 1) {
+ vname += axis;
}
var_idmap_[iv->var.get()] =
CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype());
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
# TileLang Metal codegen: emit Metal builtins directly

Filed PR: https://github.com/tile-ai/tilelang/pull/2143
Branch: `apstenku123:cppmega/metal-emit-builtins-directly`
Stacks on: tile-ai/tilelang#2130 (jorgecurious metal-gemm-upstream-rebase)

## What this fixes

TileLang's Metal codegen names the thread/block kernel-launch parameters
using the CUDA-style identifiers `blockIdx` and `threadIdx`. The MSL output
of `lower(prim_func, target='metal')` therefore looks like:

```cpp
kernel void smoke_kernel(
  device const half4* A [[ buffer(0) ]],
  device half4* C [[ buffer(1) ]],
  uint3 blockIdx [[threadgroup_position_in_grid]],
  uint3 threadIdx [[thread_position_in_threadgroup]]
) {
  C[((((int)threadIdx.x) * 4) / 4)] = A[((((int)threadIdx.x) * 4) / 4)];
}
```

The named parameters mirror CUDA's `blockIdx.x`/`threadIdx.x`, but downstream
consumers that inline the body of `kernel void` into another kernel (e.g.
the cppmega.mlx Path C ports that splice TileLang-emitted bodies into
`mx.fast.metal_kernel` `source=` strings) end up having to:

* Inject `uint3 blockIdx = threadgroup_position_in_grid;` and
  `uint3 threadIdx = thread_position_in_threadgroup;` shims so the body's
  references still bind, then
* Regex-substitute every `((int)threadIdx.x)` etc. back to
  `((int)thread_position_in_threadgroup.x)`, then
* Drop the now-dead alias declarations.

See the canonicalization helpers in cppmega.mlx
(`cppmega_mlx/nn/_tilelang/_msl_transform.py`):

* `_metal_builtin_for_tilelang_alias`
* `_rewrite_tilelang_builtin_axis`
* `_rewrite_tilelang_builtin_axis_cast`
* `_canonicalize_tilelang_builtin_aliases`
* `_drop_alias_decl_if_unused`

The whole chain is pure overhead: every consumer either lives with the
alias or post-processes it back to the Metal builtin.
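
To make that overhead concrete, here is a minimal Python sketch of the regex substitution step a consumer might run against unpatched output. The names `ALIAS_TO_BUILTIN` and `canonicalize_aliases` are hypothetical and chosen for illustration; this is not the cppmega.mlx implementation in `_msl_transform.py`, only an approximation of the idea.

```python
import re

# Hypothetical alias table (illustration only): CUDA-style name -> Metal builtin.
ALIAS_TO_BUILTIN = {
    "blockIdx": "threadgroup_position_in_grid",
    "threadIdx": "thread_position_in_threadgroup",
}

# Whole-word alias, optionally followed by a .x/.y/.z axis selector.
_ALIAS_RE = re.compile(r"\b(blockIdx|threadIdx)\b(\.[xyz])?")


def canonicalize_aliases(msl_body):
    """Rewrite CUDA-style alias references to the Metal builtin names."""
    def _sub(match):
        # Keep the axis selector (if any) attached to the new name.
        return ALIAS_TO_BUILTIN[match.group(1)] + (match.group(2) or "")
    return _ALIAS_RE.sub(_sub, msl_body)


body = "C[((((int)threadIdx.x) * 4) / 4)] = A[((((int)threadIdx.x) * 4) / 4)];"
print(canonicalize_aliases(body))
```

Because the pattern only matches the alias names, running the function on text that already uses the Metal builtins leaves it unchanged.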

## What this PR does

In TileLang's Metal codegen, the thread/block kernel parameters are now
declared using the Metal builtin identifiers themselves:

```cpp
  uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
  uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]]
```

`BindThreadIndex` translates the CUDA-style `IterVar::thread_tag`
(`"blockIdx.x"`, `"threadIdx.y"`, ...) to the matching Metal builtin
reference (`threadgroup_position_in_grid.x`, ...) before recording it in
`var_idmap_`. The body therefore emits `((int)threadgroup_position_in_grid.x)`
directly. The `name_supply_` reservation also keeps the legacy
`blockIdx`/`threadIdx` names blocked so the rest of the kernel cannot
collide with them.
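
The translation in `BindThreadIndex` can be modelled in a few lines of Python. This is a sketch that mirrors the control flow of the patch, not code from TileLang; the function name `metal_builtin_ref` and its `work_dim` parameter (standing in for `thread_work_dim_`) are assumptions made for the example.

```python
def metal_builtin_ref(thread_tag, work_dim):
    """Model of the patched name translation (hypothetical helper)."""
    name, axis = thread_tag, ""
    # Split a trailing ".x"/".y"/".z" off the CUDA-style tag.
    if len(name) >= 2 and name[-2] == ".":
        name, axis = name[:-2], name[-2:]
    # Map the CUDA-style base name to the Metal builtin name.
    if name == "threadIdx":
        name = "thread_position_in_threadgroup"
    elif name == "blockIdx":
        name = "threadgroup_position_in_grid"
    # As in the patch, the axis is re-attached only when work_dim > 1.
    if work_dim > 1:
        name += axis
    return name


print(metal_builtin_ref("blockIdx.x", 3))
print(metal_builtin_ref("threadIdx.x", 1))
```

With `work_dim` of 1 the reference is emitted without an axis, which matches the pre-patch behaviour of stripping the suffix in that case.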

### Before / after MSL

Before:

```cpp
kernel void smoke_kernel(
  device const half4* A [[ buffer(0) ]],
  device half4* C [[ buffer(1) ]],
  uint3 blockIdx [[threadgroup_position_in_grid]],
  uint3 threadIdx [[thread_position_in_threadgroup]]
) {
  C[((((int)threadIdx.x) * 4) / 4)] = A[((((int)threadIdx.x) * 4) / 4)];
}
```

After:

```cpp
kernel void smoke_kernel(
  device const half4* A [[ buffer(0) ]],
  device half4* C [[ buffer(1) ]],
  uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
  uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]]
) {
  C[((((int)thread_position_in_threadgroup.x) * 4) / 4)] = A[((((int)thread_position_in_threadgroup.x) * 4) / 4)];
}
```

## Files in this directory

* `0001-metal-emit-builtins-directly.patch` - the TileLang half. Touches
  `src/target/codegen_metal.cc`. ~36 LOC.
* `0002-tvm-metal-emit-builtins-directly.patch` - the vendored TVM half.
  Touches `3rdparty/tvm/src/target/source/codegen_metal.cc`. ~30 LOC.

The TileLang half is filed as one PR to `tile-ai/tilelang`. The TVM half
needs to land in `TileLang/tvm` (the vendored fork) before a TileLang
release that bumps the submodule.

## Stacking

This PR depends on `jorgecurious/tilelang:metal-gemm-upstream-rebase`
(PR tile-ai/tilelang#2130) for the simdgroup-store hardening that the
Path C ports rely on. The diff applies cleanly on top of that branch.

## Verification

* Build: TileLang `ninja -j8` succeeds against the patched
  `src/target/codegen_metal.cc` and submodule `codegen_metal.cc`.
* Smoke test: `lower(prim_func, target='metal')` no longer emits the
  `blockIdx`/`threadIdx` aliases; the Metal builtin names appear
  directly in the kernel signature and in body references.
* cppmega Path C tests still pass: the regex helpers in
  `_msl_transform.py` become no-ops because the emitted MSL already uses
  the Metal builtin names. The helpers are kept (they are idempotent) so
  the fallback still works against unpatched TileLang releases.
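
A smoke check of this kind can be approximated with a whole-word search over the emitted MSL text. The sketch below is an assumption about how such a check could look, not the test used for the PR; `leftover_cuda_aliases` is a hypothetical helper, and the two strings are hand-written fragments rather than real TileLang output.

```python
import re

# Whole-word match so identifiers such as `my_threadIdx` are not flagged.
_CUDA_ALIAS_RE = re.compile(r"\b(?:blockIdx|threadIdx)\b")


def leftover_cuda_aliases(msl_source):
    """Return each CUDA-style alias occurrence found in the MSL text."""
    return _CUDA_ALIAS_RE.findall(msl_source)


# Hand-written fragments for illustration (not real codegen output).
unpatched = (
    "uint3 threadIdx [[thread_position_in_threadgroup]]) "
    "{ C[((int)threadIdx.x)] = 0; }"
)
patched = (
    "uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]]) "
    "{ C[((int)thread_position_in_threadgroup.x)] = 0; }"
)

print(leftover_cuda_aliases(unpatched))  # two occurrences of threadIdx
print(leftover_cuda_aliases(patched))    # empty list
```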

## Risk

* Limited to the `tilelang_metal` and `metal` codegen paths in TileLang
  and the vendored TVM fork respectively. CUDA, ROCm, OpenCL, WebGPU,
  and CPU codegen are unaffected.
* MSL parameter names are user-chosen identifiers; renaming the
  parameter does not change semantics because the
  `[[threadgroup_position_in_grid]]` attribute is what binds the value.
  Apple's MSL spec permits the parameter identifier to match the
  attribute name.
* `name_supply_` reserves both the new Metal builtin names and the
  legacy `blockIdx`/`threadIdx` names so a future user-defined symbol
  cannot collide, and the existing assertion-based contract on the
  reservation order is preserved.
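
The reservation contract can be illustrated with a toy name supply in Python. This is a simplified model written for this note, not TVM's `NameSupply`; in particular the `_1`, `_2` suffix scheme is an assumption of the sketch. The point is only that the first request for a name returns it unchanged (which is what the `ICHECK_EQ` calls assert) and that later requests cannot reuse it.

```python
class ToyNameSupply:
    """Toy model: hands out each name once, then suffixes repeats."""

    def __init__(self):
        self._counts = {}

    def fresh_name(self, name):
        if name not in self._counts:
            self._counts[name] = 0
            return name
        # Suffix scheme is an assumption of this sketch.
        while True:
            self._counts[name] += 1
            candidate = f"{name}_{self._counts[name]}"
            if candidate not in self._counts:
                self._counts[candidate] = 0
                return candidate


supply = ToyNameSupply()
reserved = [
    "threadIdx",
    "blockIdx",
    "threadgroup_position_in_grid",
    "thread_position_in_threadgroup",
]
for name in reserved:
    # Mirrors the ICHECK_EQ contract: a reserved name comes back unchanged.
    assert supply.fresh_name(name) == name

# A later request for a reserved name gets a different identifier.
print(supply.fresh_name("blockIdx"))
```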
