|
| 1 | +diff --git a/src/target/codegen_metal.cc b/src/target/codegen_metal.cc |
| 2 | +index faa43a35..025a98e2 100644 |
| 3 | +--- a/src/target/codegen_metal.cc |
| 4 | ++++ b/src/target/codegen_metal.cc |
| 5 | +@@ -57,6 +57,98 @@ CodeGenTileLangMetal::CodeGenTileLangMetal(Target target) : target_(target) { |
| 6 | + << "};\n\n"; |
| 7 | + } |
| 8 | + |
| 9 | +// Inline MSL helpers for storage-only FP8 emulation (e4m3 / e5m2). |
| 10 | +// Apple Silicon (M4 Max and earlier; M5 NAX is FP16/INT8 only) has NO native |
| 11 | +// FP8 ALU support, so FP8 is realised as `uchar` storage with explicit |
| 12 | +// dequantize-on-load / quantize-on-store. Encodings follow the OCP OFP8 spec: |
| 13 | +// E4M3 has no Inf (only 0x7F/0xFF are NaN, max finite is 448 = 0x7E) and |
| 14 | +// half->E4M3 saturates on finite overflow; E5M2 is IEEE-style with Inf/NaN. |
| 15 | +void CodeGenTileLangMetal::PrintFP8Prelude(std::ostream &os) { |
| 16 | + os << |
| 17 | + "// FP8 storage-only emulation helpers (MSL has no native float8 type).\n" |
| 18 | + "// See OCP \"OFP8 Formats for Deep Learning\" v1.0 spec.\n" |
| 19 | + "inline half __tvm_fp8_e4m3_to_half(uchar x) {\n" |
| 20 | + " ushort sign = (ushort)(x & 0x80) << 8;\n" |
| 21 | + " ushort mant = (ushort)(x & 0x07);\n" |
| 22 | + " ushort exp = (ushort)((x >> 3) & 0x0F);\n" |
| 23 | + " ushort h;\n" |
| 24 | + " if (exp == 0) {\n" |
| 25 | + " if (mant == 0) {\n" |
| 26 | + " h = sign;\n" |
| 27 | + " } else {\n" |
| 28 | + " // subnormal: e4m3 value = mant * 2^-9. After shifting the\n" |
| 29 | + " // mantissa so the leading 1 hits bit 2 (0x4), the half\n" |
| 30 | + " // biased exponent is (e + 7), not (e + 8).\n" |
| 31 | + " ushort m = mant;\n" |
| 32 | + " ushort e = 1;\n" |
| 33 | + " while ((m & 0x4) == 0) { m <<= 1; e -= 1; }\n" |
| 34 | + " m &= 0x3;\n" |
| 35 | + " h = (ushort)(sign | ((ushort)(e + 7) << 10) | (ushort)(m << 8));\n" |
| 36 | + " }\n" |
| 37 | + " } else if (exp == 0x0F && mant == 0x07) {\n" |
| 38 | + " h = (ushort)(sign | 0x7E00);\n" |
| 39 | + " } else {\n" |
| 40 | + " h = (ushort)(sign | ((ushort)(exp + 8) << 10) | (ushort)(mant << 7));\n" |
| 41 | + " }\n" |
| 42 | + " return as_type<half>(h);\n" |
| 43 | + "}\n" |
| 44 | + "inline half __tvm_fp8_e5m2_to_half(uchar x) {\n" |
| 45 | + " ushort h = ((ushort)x) << 8;\n" |
| 46 | + " return as_type<half>(h);\n" |
| 47 | + "}\n" |
| 48 | + "inline uchar __tvm_half_to_fp8_e4m3(half v) {\n" |
| 49 | + " ushort h = as_type<ushort>(v);\n" |
| 50 | + " ushort sign = (h >> 8) & 0x80;\n" |
| 51 | + " short he = (short)((h >> 10) & 0x1F);\n" |
| 52 | + " ushort hm = h & 0x3FF;\n" |
| 53 | + " if (he == 0x1F) {\n" |
| 54 | + " return (uchar)(sign | 0x7F);\n" |
| 55 | + " }\n" |
| 56 | + " short e = he - 8;\n" |
| 57 | + // Exponent field 15 still encodes finite values (256..448): only e > 15 saturates here. |
| 58 | + " if (e > 0x0F) {\n" |
| 59 | + " return (uchar)(sign | 0x7E);\n" |
| 60 | + " }\n" |
| 61 | + " if (e <= 0) {\n" |
| 62 | + " if (e < -3) return (uchar)sign;\n" |
| 63 | + " ushort m = hm | 0x400;\n" |
| 64 | + " ushort shift = (ushort)(7 + 1 - e);\n" |
| 65 | + " ushort round_bit = (ushort)1 << (shift - 1);\n" |
| 66 | + " ushort q = m >> shift;\n" |
| 67 | + " ushort rem = m & ((round_bit << 1) - 1);\n" |
| 68 | + " if (rem > round_bit || (rem == round_bit && (q & 1))) q += 1;\n" |
| 69 | + " return (uchar)(sign | (q & 0x7F));\n" |
| 70 | + " }\n" |
| 71 | + " ushort q = hm >> 7;\n" |
| 72 | + " ushort rem = hm & 0x7F;\n" |
| 73 | + " if (rem > 0x40 || (rem == 0x40 && (q & 1))) {\n" |
| 74 | + " q += 1;\n" |
| 75 | + " if (q == 0x08) { q = 0; e += 1; }\n" |
| 76 | + " }\n" |
| 77 | + // After rounding, clamp to the max finite 448 (0x7E) instead of emitting the NaN pattern 0x7F. |
| 78 | + " if (e > 0x0F || (e == 0x0F && q == 0x07)) return (uchar)(sign | 0x7E);\n" |
| 79 | + " return (uchar)(sign | (ushort)(e << 3) | (q & 0x07));\n" |
| 80 | + "}\n" |
| 81 | + "inline uchar __tvm_half_to_fp8_e5m2(half v) {\n" |
| 82 | + " ushort h = as_type<ushort>(v);\n" |
| 83 | + " ushort sign = h & 0x8000;\n" |
| 84 | + " ushort exp = (h >> 10) & 0x1F;\n" |
| 85 | + " ushort mant = h & 0x3FF;\n" |
| 86 | + " if (exp == 0x1F) {\n" |
| 87 | + " if (mant != 0) return (uchar)((sign >> 8) | 0x7E);\n" |
| 88 | + " return (uchar)((sign >> 8) | 0x7C);\n" |
| 89 | + " }\n" |
| 90 | + " ushort q = mant >> 8;\n" |
| 91 | + " ushort rem = mant & 0xFF;\n" |
| 92 | + " if (rem > 0x80 || (rem == 0x80 && (q & 1))) {\n" |
| 93 | + " q += 1;\n" |
| 94 | + " if (q == 0x4) { q = 0; exp += 1; }\n" |
| 95 | + " if (exp == 0x1F) return (uchar)((sign >> 8) | 0x7C);\n" |
| 96 | + " }\n" |
| 97 | + " return (uchar)((sign >> 8) | (uchar)(exp << 2) | (uchar)(q & 0x3));\n" |
| 98 | + "}\n\n"; |
| 99 | +} |
| 100 | ++ |
| 101 | + void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar, |
| 102 | + const PrimFunc &func) { |
| 103 | + // NOTE: There is no inter-function calls among Metal kernels. |
| 104 | +@@ -275,6 +367,27 @@ void CodeGenTileLangMetal::PrintType(DataType t, |
| 105 | + } else if (t.is_bfloat16()) { |
| 106 | + os << "bfloat"; |
| 107 | + return; |
| 108 | ++ } else if (t.is_float8()) { |
| 109 | ++ // FP8 is storage-only on Metal: print as `uchar`/`ucharN` and emit explicit |
| 110 | ++ // dequantize/quantize helpers via the FP8 prelude. Caller-side casts must |
| 111 | ++ // route through __tvm_fp8_*_to_half / __tvm_half_to_fp8_*. |
| 112 | ++ enable_fp8_ = true; |
| 113 | ++ if (lanes == 1) { |
| 114 | ++ os << "uchar"; |
| 115 | ++ return; |
| 116 | ++ } |
| 117 | ++ if (lanes >= 2 && lanes <= 4) { |
| 118 | ++ os << "uchar" << lanes; |
| 119 | ++ return; |
| 120 | ++ } |
| 121 | ++ if (lanes == 8) { |
| 122 | ++ os << "uint2"; |
| 123 | ++ return; |
| 124 | ++ } |
| 125 | ++ if (lanes == 16) { |
| 126 | ++ os << "uint4"; |
| 127 | ++ return; |
| 128 | ++ } |
| 129 | + } |
| 130 | + LOG(FATAL) << "Cannot convert type " << t << " to Metal type"; |
| 131 | + } |
| 132 | +@@ -517,6 +630,73 @@ void CodeGenTileLangMetal::VisitExpr_(const CallNode *op, |
| 133 | + } |
| 134 | + } |
| 135 | + |
| 136 | +void CodeGenTileLangMetal::VisitExpr_(const CastNode *op, |
| 137 | +                                      std::ostream &os) { // NOLINT(*) |
| 138 | +  // Casts that touch an FP8 dtype are lowered to the storage-only helpers, |
| 139 | +  // always going through `half`; every other cast is left to CodeGenC. |
| 140 | +  const DataType src_ty = op->value.dtype(); |
| 141 | +  const DataType dst_ty = op->dtype; |
| 142 | +  const bool src_is_fp8 = src_ty.is_float8(); |
| 143 | +  const bool dst_is_fp8 = dst_ty.is_float8(); |
| 144 | +  if (src_is_fp8 || dst_is_fp8) { |
| 145 | +    // Record that the helper definitions are needed in the final source. |
| 146 | +    enable_fp8_ = true; |
| 147 | +    ICHECK_EQ(dst_ty.lanes(), src_ty.lanes()) |
| 148 | +        << "FP8 vector cast lanes must match: " << src_ty << " -> " |
| 149 | +        << dst_ty; |
| 150 | +    if (dst_ty.lanes() == 1) { |
| 151 | +      // Step 1: turn the operand into an expression of type half. |
| 152 | +      std::string half_expr = PrintExpr(op->value); |
| 153 | +      if (src_is_fp8) { |
| 154 | +        // Any FP8 code other than e5m2 is decoded with the e4m3 helper. |
| 155 | +        const char *decode = src_ty.code() == DataType::kFloat8_e5m2 |
| 156 | +                                 ? "__tvm_fp8_e5m2_to_half" |
| 157 | +                                 : "__tvm_fp8_e4m3_to_half"; |
| 158 | +        half_expr = std::string(decode) + "(" + half_expr + ")"; |
| 159 | +      } else if (!(src_ty == DataType::Float(16))) { |
| 160 | +        // Non-half sources are converted to half before quantisation. |
| 161 | +        half_expr = "((half)(" + half_expr + "))"; |
| 162 | +      } |
| 163 | +      // Step 2: emit the half value in the requested destination type. |
| 164 | +      if (dst_is_fp8) { |
| 165 | +        // Any FP8 code other than e5m2 is encoded with the e4m3 helper. |
| 166 | +        const char *encode = dst_ty.code() == DataType::kFloat8_e5m2 |
| 167 | +                                 ? "__tvm_half_to_fp8_e5m2" |
| 168 | +                                 : "__tvm_half_to_fp8_e4m3"; |
| 169 | +        os << encode << "(" << half_expr << ")"; |
| 170 | +      } else if (dst_ty == DataType::Float(16)) { |
| 171 | +        os << half_expr; |
| 172 | +      } else { |
| 173 | +        os << "(("; |
| 174 | +        PrintType(dst_ty, os); |
| 175 | +        os << ")(" << half_expr << "))"; |
| 176 | +      } |
| 177 | +      return; |
| 178 | +    } |
| 179 | +    LOG(FATAL) << "Vector FP8 casts (lanes=" << dst_ty.lanes() |
| 180 | +               << ") are not yet supported by Metal storage-only FP8 emulation;" |
| 181 | +               << " scalarise the cast or extend codegen_metal.cc."; |
| 182 | +  } |
| 183 | +  CodeGenC::VisitExpr_(op, os); |
| 184 | +} |
| 185 | ++ |
| 186 | +std::string CodeGenTileLangMetal::Finish() { |
| 187 | +  // Render the FP8 helpers first; nothing is rendered when FP8 was never used. |
| 188 | +  std::ostringstream helper_stream; |
| 189 | +  if (enable_fp8_) |
| 190 | +    PrintFP8Prelude(helper_stream); |
| 191 | +  const std::string helpers = helper_stream.str(); |
| 192 | +  std::string source = CodeGenC::Finish(); |
| 193 | +  if (helpers.empty()) |
| 194 | +    return source; |
| 195 | +  // Helpers go right after the namespace directive, or first if it is absent. |
| 196 | +  const std::string marker = "using namespace metal;\n"; |
| 197 | +  const auto marker_at = source.find(marker); |
| 198 | +  if (marker_at == std::string::npos) |
| 199 | +    return helpers + source; |
| 200 | +  return source.insert(marker_at + marker.size(), "\n" + helpers); |
| 201 | +} |
| 202 | ++ |
| 203 | + void CodeGenTileLangMetal::VisitExpr_(const FloatImmNode *op, |
| 204 | + std::ostream &os) { // NOLINT(*) |
| 205 | + std::ostringstream temp; |
| 206 | +diff --git a/src/target/codegen_metal.h b/src/target/codegen_metal.h |
| 207 | +index 3a711b4e..f2f41e40 100644 |
| 208 | +--- a/src/target/codegen_metal.h |
| 209 | ++++ b/src/target/codegen_metal.h |
| 210 | +@@ -60,15 +60,25 @@ public: |
| 211 | + void VisitExpr_(const SelectNode *op, std::ostream &os) final; // NOLINT(*) |
| 212 | + void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*) |
| 213 | + void VisitExpr_(const CallNode *op, std::ostream &os) final; // NOLINT(*) |
| 214 | ++ void VisitExpr_(const CastNode *op, std::ostream &os) final; // NOLINT(*) |
| 215 | + void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; // NOLINT(*) |
| 216 | + |
| 217 | ++ // Override to inject FP8 prelude (storage-only emulation helpers) when |
| 218 | ++ // any FP8 dtype was referenced. |
| 219 | ++ std::string Finish() final; |
| 220 | ++ |
| 221 | + // reuse parent's function. |
| 222 | + using CodeGenC::PrintType; |
| 223 | + |
| 224 | + private: |
| 225 | ++ // Emit inline MSL helpers for storage-only FP8 (e4m3 / e5m2) emulation. |
| 226 | ++ void PrintFP8Prelude(std::ostream &os); |
| 227 | ++ |
| 228 | + std::unordered_map<const VarNode *, std::string> simdgroup_dtype_; |
| 229 | + int thread_index_bits_{32}; |
| 230 | + int thread_work_dim_{0}; |
| 231 | ++ // Set when an FP8 dtype is referenced; gates emission of FP8 prelude helpers. |
| 232 | ++ bool enable_fp8_{false}; |
| 233 | + Target target_; |
| 234 | + }; |
| 235 | + } // namespace codegen |
0 commit comments