Commit 9a0d679

docs(upstream): split FP8 patches + 4 more PRs filed (PRs #2144/#2145, #37/#38/#39)
Three parallel agents completed the supermodule/submodule split filing:

1. tilelang_metal_fp8 (storage-only FP8 emulation) split:
   - 0001-tilelang-metal-fp8-storage-only.patch: supermodule half (235 lines)
   - 0002-tvm-metal-fp8-storage-only.patch: TVM-mirror half (260 lines, prefix stripped)
   - PR tile-ai/tilelang#2144 (supermodule, stacks on PR #2130)
   - PR tile-ai/tvm#38 (TVM mirror, base tilelang_main @ 0e15b274)

2. tilelang_metal_fp8_vector (vector cast lanes 2/3/4) split:
   - 0001-tilelang-metal-fp8-vector-cast.patch: supermodule half (148 lines)
   - 0002-tvm-metal-fp8-vector-cast.patch: TVM-mirror half (151 lines)
   - PR tile-ai/tilelang#2145 (supermodule, depends on #2144)
   - PR tile-ai/tvm#39 (TVM mirror, depends on #38)

3. PR #2143 TVM-mirror companion:
   - PR tile-ai/tvm#37: already filed, README updated to link both halves

Total filed today: 11 PRs across 3 repos
   - 1 ml-explore/mlx (#3476)
   - 1 apache/tvm (#19504)
   - 6 tile-ai/tilelang (#2139, #2140, #2141, #2142, #2143 super, #2144 super, #2145 super)
   - 3 tile-ai/tvm (#37, #38, #39; TVM-mirror companions)

PR #2142 (T.fp8_scaled_matmul) needs no TVM-mirror companion: the patch was verified to touch only supermodule files.

All splits round-trip clean (apply forward + reverse) on their respective bases. README files in each docs/upstream/<dir>/ were updated with PR URLs and dependency-chain diagrams.

Note: TileLang/tvm redirects to tile-ai/tvm server-side (canonical org slug). All TVM-mirror PRs land at tile-ai/tvm/pull/N URLs.
1 parent ac80716 commit 9a0d679

9 files changed

Lines changed: 733 additions & 533 deletions

docs/upstream/_filed_prs_2026_05_04.md

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ GitHub user: `apstenku123` (David Gornshtein, davidgornshtein@gmail.com)
 | 4 | `tile-ai/tilelang` | [#2140](https://github.com/tile-ai/tilelang/pull/2140) | [Metal] route FP8-input T.gemm to scalar fallback | filed against `main`, **stacks on PR #2130** |
 | 5 | `tile-ai/tilelang` | [#2141](https://github.com/tile-ai/tilelang/pull/2141) | [Metal] thread stage dim through T.access_ptr for T.Pipelined num_stages>1 | filed against `main`, **stacks on PR #2130** |
 | 6 | `tile-ai/tilelang` | [#2142](https://github.com/tile-ai/tilelang/pull/2142) | tilelang: T.fp8_scaled_matmul DSL intrinsic + Metal lowering | filed against `main`, **stacks on PR #2130** |
+| 7 | `tile-ai/tilelang` | [#2143](https://github.com/tile-ai/tilelang/pull/2143) | [Metal] emit Metal builtins directly instead of CUDA-style threadIdx/blockIdx aliases (supermodule half) | filed against `main`, **stacks on PR #2130** |
+| 8 | `tile-ai/tvm` (TileLang/tvm) | [#37](https://github.com/tile-ai/tvm/pull/37) | [Metal] emit Metal builtins directly instead of CUDA-style threadIdx/blockIdx aliases (TVM-mirror half) | filed against `main`, paired with tile-ai/tilelang#2143 |

 ## Deferred (split needed: tilelang supermodule + TileLang/tvm submodule)

docs/upstream/tilelang_metal_emit_metal_builtins/README.md

Lines changed: 6 additions & 2 deletions
@@ -1,7 +1,11 @@
 # TileLang Metal codegen: emit Metal builtins directly

-Filed PR: https://github.com/tile-ai/tilelang/pull/2143
-Branch: `apstenku123:cppmega/metal-emit-builtins-directly`
+Filed PRs:
+
+* Tilelang supermodule half: https://github.com/tile-ai/tilelang/pull/2143
+* TileLang/tvm submodule mirror: https://github.com/tile-ai/tvm/pull/37
+
+Branch: `apstenku123:cppmega/metal-emit-builtins-directly` (same name on both forks)
 Stacks on: tile-ai/tilelang#2130 (jorgecurious metal-gemm-upstream-rebase)

 ## What this fixes

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
+diff --git a/src/target/codegen_metal.cc b/src/target/codegen_metal.cc
+index faa43a35..025a98e2 100644
+--- a/src/target/codegen_metal.cc
++++ b/src/target/codegen_metal.cc
+@@ -57,6 +57,98 @@ CodeGenTileLangMetal::CodeGenTileLangMetal(Target target) : target_(target) {
+               << "};\n\n";
+ }
+
++// Inline MSL helpers for storage-only FP8 emulation (e4m3 / e5m2).
++// Apple Silicon (M4 Max and earlier; M5 NAX is FP16/INT8 only) has NO native
++// FP8 ALU support, so FP8 is realised as `uchar` storage with explicit
++// dequantize-on-load / quantize-on-store. The helpers mirror the IEEE 754
++// derived encoding from the OFP8 spec (E4M3 with finite-only encoding, E5M2
++// IEEE-style with NaN/Inf).
++void CodeGenTileLangMetal::PrintFP8Prelude(std::ostream &os) {
++  os <<
++      "// FP8 storage-only emulation helpers (MSL has no native float8 type).\n"
++      "// See OCP \"OFP8 Formats for Deep Learning\" v1.0 spec.\n"
++      "inline half __tvm_fp8_e4m3_to_half(uchar x) {\n"
++      " ushort sign = (ushort)(x & 0x80) << 8;\n"
++      " ushort mant = (ushort)(x & 0x07);\n"
++      " ushort exp = (ushort)((x >> 3) & 0x0F);\n"
++      " ushort h;\n"
++      " if (exp == 0) {\n"
++      " if (mant == 0) {\n"
++      " h = sign;\n"
++      " } else {\n"
++      " // subnormal: e4m3 value = mant * 2^-9. After shifting the\n"
++      " // mantissa so the leading 1 hits bit 2 (0x4), the half\n"
++      " // biased exponent is (e + 7), not (e + 8).\n"
++      " ushort m = mant;\n"
++      " ushort e = 1;\n"
++      " while ((m & 0x4) == 0) { m <<= 1; e -= 1; }\n"
++      " m &= 0x3;\n"
++      " h = (ushort)(sign | ((ushort)(e + 7) << 10) | (ushort)(m << 8));\n"
++      " }\n"
++      " } else if (exp == 0x0F && mant == 0x07) {\n"
++      " h = (ushort)(sign | 0x7E00);\n"
++      " } else {\n"
++      " h = (ushort)(sign | ((ushort)(exp + 8) << 10) | (ushort)(mant << 7));\n"
++      " }\n"
++      " return as_type<half>(h);\n"
++      "}\n"
++      "inline half __tvm_fp8_e5m2_to_half(uchar x) {\n"
++      " ushort h = ((ushort)x) << 8;\n"
++      " return as_type<half>(h);\n"
++      "}\n"
++      "inline uchar __tvm_half_to_fp8_e4m3(half v) {\n"
++      " ushort h = as_type<ushort>(v);\n"
++      " ushort sign = (h >> 8) & 0x80;\n"
++      " short he = (short)((h >> 10) & 0x1F);\n"
++      " ushort hm = h & 0x3FF;\n"
++      " if (he == 0x1F) {\n"
++      " return (uchar)(sign | 0x7F);\n"
++      " }\n"
++      " short e = he - 8;\n"
++      " if (e >= 0x0F) {\n"
++      " return (uchar)(sign | 0x7E);\n"
++      " }\n"
++      " if (e <= 0) {\n"
++      " if (e < -3) return (uchar)sign;\n"
++      " ushort m = hm | 0x400;\n"
++      " ushort shift = (ushort)(7 + 1 - e);\n"
++      " ushort round_bit = (ushort)1 << (shift - 1);\n"
++      " ushort sticky = m & (round_bit - 1);\n"
++      " ushort q = m >> shift;\n"
++      " ushort rem = m & ((round_bit << 1) - 1);\n"
++      " if (rem > round_bit || (rem == round_bit && (q & 1))) q += 1;\n"
++      " (void)sticky;\n"
++      " return (uchar)(sign | (q & 0x7F));\n"
++      " }\n"
++      " ushort q = hm >> 7;\n"
++      " ushort rem = hm & 0x7F;\n"
++      " if (rem > 0x40 || (rem == 0x40 && (q & 1))) {\n"
++      " q += 1;\n"
++      " if (q == 0x08) { q = 0; e += 1; }\n"
++      " if (e >= 0x0F) return (uchar)(sign | 0x7E);\n"
++      " }\n"
++      " return (uchar)(sign | (ushort)(e << 3) | (q & 0x07));\n"
++      "}\n"
++      "inline uchar __tvm_half_to_fp8_e5m2(half v) {\n"
++      " ushort h = as_type<ushort>(v);\n"
++      " ushort sign = h & 0x8000;\n"
++      " ushort exp = (h >> 10) & 0x1F;\n"
++      " ushort mant = h & 0x3FF;\n"
++      " if (exp == 0x1F) {\n"
++      " if (mant != 0) return (uchar)((sign >> 8) | 0x7E);\n"
++      " return (uchar)((sign >> 8) | 0x7C);\n"
++      " }\n"
++      " ushort q = mant >> 8;\n"
++      " ushort rem = mant & 0xFF;\n"
++      " if (rem > 0x80 || (rem == 0x80 && (q & 1))) {\n"
++      " q += 1;\n"
++      " if (q == 0x4) { q = 0; exp += 1; }\n"
++      " if (exp == 0x1F) return (uchar)((sign >> 8) | 0x7C);\n"
++      " }\n"
++      " return (uchar)((sign >> 8) | (uchar)(exp << 2) | (uchar)(q & 0x3));\n"
++      "}\n\n";
++}
++
+ void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar,
+                                        const PrimFunc &func) {
+   // NOTE: There is no inter-function calls among Metal kernels.
+@@ -275,6 +367,27 @@ void CodeGenTileLangMetal::PrintType(DataType t,
+   } else if (t.is_bfloat16()) {
+     os << "bfloat";
+     return;
++  } else if (t.is_float8()) {
++    // FP8 is storage-only on Metal: print as `uchar`/`ucharN` and emit explicit
++    // dequantize/quantize helpers via the FP8 prelude. Caller-side casts must
++    // route through __tvm_fp8_*_to_half / __tvm_half_to_fp8_*.
++    enable_fp8_ = true;
++    if (lanes == 1) {
++      os << "uchar";
++      return;
++    }
++    if (lanes >= 2 && lanes <= 4) {
++      os << "uchar" << lanes;
++      return;
++    }
++    if (lanes == 8) {
++      os << "uint2";
++      return;
++    }
++    if (lanes == 16) {
++      os << "uint4";
++      return;
++    }
+   }
+   LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
+ }
+@@ -517,6 +630,73 @@ void CodeGenTileLangMetal::VisitExpr_(const CallNode *op,
+   }
+ }
+
++void CodeGenTileLangMetal::VisitExpr_(const CastNode *op,
++                                      std::ostream &os) { // NOLINT(*)
++  DataType from_ty = op->value.dtype();
++  DataType target_ty = op->dtype;
++  if (target_ty.is_float8() || from_ty.is_float8()) {
++    enable_fp8_ = true;
++    ICHECK_EQ(target_ty.lanes(), from_ty.lanes())
++        << "FP8 vector cast lanes must match: " << from_ty << " -> "
++        << target_ty;
++    auto fp8_to_half = [&](DataType ft, std::string val) {
++      const char *helper = ft.code() == DataType::kFloat8_e5m2
++          ? "__tvm_fp8_e5m2_to_half"
++          : "__tvm_fp8_e4m3_to_half";
++      return std::string(helper) + "(" + val + ")";
++    };
++    auto half_to_fp8 = [&](DataType tt, std::string val) {
++      const char *helper = tt.code() == DataType::kFloat8_e5m2
++          ? "__tvm_half_to_fp8_e5m2"
++          : "__tvm_half_to_fp8_e4m3";
++      return std::string(helper) + "(" + val + ")";
++    };
++    if (target_ty.lanes() == 1) {
++      std::string val = PrintExpr(op->value);
++      if (from_ty.is_float8() && !target_ty.is_float8()) {
++        std::string h = fp8_to_half(from_ty, val);
++        if (target_ty == DataType::Float(16)) {
++          os << h;
++        } else {
++          os << "((";
++          PrintType(target_ty, os);
++          os << ")(" << h << "))";
++        }
++      } else if (!from_ty.is_float8() && target_ty.is_float8()) {
++        std::string h = from_ty == DataType::Float(16)
++            ? val
++            : "((half)(" + val + "))";
++        os << half_to_fp8(target_ty, h);
++      } else {
++        std::string h = fp8_to_half(from_ty, val);
++        os << half_to_fp8(target_ty, h);
++      }
++      return;
++    }
++    LOG(FATAL) << "Vector FP8 casts (lanes=" << target_ty.lanes()
++               << ") are not yet supported by Metal storage-only FP8 emulation;"
++               << " scalarise the cast or extend codegen_metal.cc.";
++  }
++  CodeGenC::VisitExpr_(op, os);
++}
++
++std::string CodeGenTileLangMetal::Finish() {
++  std::ostringstream prelude;
++  if (enable_fp8_) {
++    PrintFP8Prelude(prelude);
++  }
++  std::string base = CodeGenC::Finish();
++  if (prelude.str().empty())
++    return base;
++  const std::string anchor = "using namespace metal;\n";
++  auto pos = base.find(anchor);
++  if (pos == std::string::npos) {
++    return prelude.str() + base;
++  }
++  pos += anchor.size();
++  return base.substr(0, pos) + "\n" + prelude.str() + base.substr(pos);
++}
++
+ void CodeGenTileLangMetal::VisitExpr_(const FloatImmNode *op,
+                                       std::ostream &os) { // NOLINT(*)
+   std::ostringstream temp;
diff --git a/src/target/codegen_metal.h b/src/target/codegen_metal.h
207+
index 3a711b4e..f2f41e40 100644
208+
--- a/src/target/codegen_metal.h
209+
+++ b/src/target/codegen_metal.h
210+
@@ -60,15 +60,25 @@ public:
211+
void VisitExpr_(const SelectNode *op, std::ostream &os) final; // NOLINT(*)
212+
void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
213+
void VisitExpr_(const CallNode *op, std::ostream &os) final; // NOLINT(*)
214+
+ void VisitExpr_(const CastNode *op, std::ostream &os) final; // NOLINT(*)
215+
void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; // NOLINT(*)
216+
217+
+ // Override to inject FP8 prelude (storage-only emulation helpers) when
218+
+ // any FP8 dtype was referenced.
219+
+ std::string Finish() final;
220+
+
221+
// reuse parent's function.
222+
using CodeGenC::PrintType;
223+
224+
private:
225+
+ // Emit inline MSL helpers for storage-only FP8 (e4m3 / e5m2) emulation.
226+
+ void PrintFP8Prelude(std::ostream &os);
227+
+
228+
std::unordered_map<const VarNode *, std::string> simdgroup_dtype_;
229+
int thread_index_bits_{32};
230+
int thread_work_dim_{0};
231+
+ // Set when an FP8 dtype is referenced; gates emission of FP8 prelude helpers.
232+
+ bool enable_fp8_{false};
233+
Target target_;
234+
};
235+
} // namespace codegen
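
The bit layouts that the MSL helpers in this patch rely on can be sanity-checked on the host. The following is a minimal sketch, not part of the patch: the names `DecodeE4M3`, `DecodeE5M2`, and `CheckKnownValues` are hypothetical, and the decoders produce `float` rather than `half` to keep the example portable. It assumes the OFP8 conventions named in the patch comments: E4M3 with bias 7 and a finite-only encoding where only S.1111.111 is NaN, and E5M2 with bias 15 and IEEE-style Inf/NaN.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Decode one OFP8 E4M3 byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa
// bits. Assumes the finite-only variant: no Inf, and S.1111.111 is NaN.
float DecodeE4M3(std::uint8_t x) {
  const float sign = (x & 0x80) ? -1.0f : 1.0f;
  const int exp = (x >> 3) & 0x0F;
  const int mant = x & 0x07;
  if (exp == 0x0F && mant == 0x07) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  if (exp == 0) {
    // Subnormal: (mant / 8) * 2^-6 == mant * 2^-9.
    return sign * std::ldexp(static_cast<float>(mant), -9);
  }
  return sign * std::ldexp(1.0f + mant / 8.0f, exp - 7);
}

// Decode one OFP8 E5M2 byte: 1 sign bit, 5 exponent bits (bias 15), 2 mantissa
// bits. The layout is the top byte of an IEEE half, so Inf and NaN exist.
float DecodeE5M2(std::uint8_t x) {
  const float sign = (x & 0x80) ? -1.0f : 1.0f;
  const int exp = (x >> 2) & 0x1F;
  const int mant = x & 0x03;
  if (exp == 0x1F) {
    if (mant == 0) return sign * std::numeric_limits<float>::infinity();
    return std::numeric_limits<float>::quiet_NaN();
  }
  if (exp == 0) {
    // Subnormal: (mant / 4) * 2^-14 == mant * 2^-16.
    return sign * std::ldexp(static_cast<float>(mant), -16);
  }
  return sign * std::ldexp(1.0f + mant / 4.0f, exp - 15);
}

// A few hand-computed reference points for both formats.
void CheckKnownValues() {
  assert(DecodeE4M3(0x38) == 1.0f);    // exp 7, mant 0
  assert(DecodeE4M3(0x7E) == 448.0f);  // largest finite E4M3: 1.75 * 2^8
  assert(DecodeE4M3(0x08) == std::ldexp(1.0f, -6));  // smallest normal
  assert(std::isnan(DecodeE4M3(0x7F)));
  assert(DecodeE5M2(0x3C) == 1.0f);      // exp 15, mant 0
  assert(DecodeE5M2(0x7B) == 57344.0f);  // largest finite E5M2: 1.75 * 2^15
  assert(std::isinf(DecodeE5M2(0x7C)));
}
```

Comparing these host-side values against the output of a Metal kernel that calls `__tvm_fp8_e4m3_to_half` or `__tvm_fp8_e5m2_to_half` on all 256 byte patterns would be one way to exercise the decode path exhaustively; the encode path would need a separate rounding reference.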
