|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <sycl/sycl.hpp> |
| 4 | +#include <cstdint> |
| 5 | +#include <limits> |
| 6 | + |
| 7 | +inline uint8_t float_to_e4m3(float f) |
| 8 | +{ |
| 9 | + if (sycl::isnan(f)) { |
| 10 | + return 0x7F; // Canonical NaN (positive) |
| 11 | + } |
| 12 | + |
| 13 | + uint32_t bits = sycl::bit_cast<uint32_t>(f); |
| 14 | + uint32_t sign = (bits >> 31) & 0x1u; |
| 15 | + uint32_t exp = (bits >> 23) & 0xFFu; |
| 16 | + uint32_t mant = bits & 0x7FFFFFu; |
| 17 | + |
| 18 | + // Zero |
| 19 | + if (exp == 0 && mant == 0) { |
| 20 | + return static_cast<uint8_t>(sign << 7); |
| 21 | + } |
| 22 | + |
| 23 | + // Extract biased exponent and mantissa for FP8 |
| 24 | + int e = static_cast<int>(exp) - 127; // true exponent (IEEE bias 127) |
| 25 | + uint32_t m = mant; |
| 26 | + |
| 27 | + // Handle very large values → NaN (NVIDIA behavior for E4M3) |
| 28 | + if (e > 7) { // max exponent for E4M3 is 7 (biased 14) |
| 29 | + return static_cast<uint8_t>((sign << 7) | 0x7F); |
| 30 | + } |
| 31 | + |
| 32 | + // Handle subnormals and normal numbers |
| 33 | + if (e < -6) { // smallest normal exponent is -6 |
| 34 | + // Subnormal in FP8: shift mantissa right |
| 35 | + int shift = -6 - e; |
| 36 | + m = (m | 0x800000u) >> (shift + 1); // +1 because we lose the implicit 1 position |
| 37 | + if (shift > 23) m = 0; |
| 38 | + } else { |
| 39 | + // Normal number: adjust exponent bias from 127 to 7 |
| 40 | + int new_exp = e + 7; |
| 41 | + m = (m >> 20) & 0x7u; // take top 3 mantissa bits (after implicit 1) |
| 42 | + m |= (static_cast<uint32_t>(new_exp) << 3); |
| 43 | + } |
| 44 | + |
| 45 | + // Round-to-nearest-even (simple guard + round bit) |
| 46 | + // For better accuracy you can add sticky bit, but this is sufficient for most use cases |
| 47 | + uint32_t round_bit = (mant >> 19) & 0x1u; // bit after the 3 mantissa bits |
| 48 | + if (round_bit) { |
| 49 | + m += 1; |
| 50 | + // Carry into exponent if mantissa overflows |
| 51 | + if ((m & 0x8u) != 0) { |
| 52 | + m = (m & 0x7u) | ((m & 0x38u) << 1); // simple carry handling |
| 53 | + // If exponent overflows after carry → NaN |
| 54 | + if ((m >> 3) > 14) { |
| 55 | + return static_cast<uint8_t>((sign << 7) | 0x7F); |
| 56 | + } |
| 57 | + } |
| 58 | + } |
| 59 | + |
| 60 | + uint8_t result = static_cast<uint8_t>((sign << 7) | (m & 0x7F)); |
| 61 | + return result; |
| 62 | +} |
| 63 | + |
| 64 | +inline float e4m3_to_float(uint8_t x) |
| 65 | +{ |
| 66 | + if (x == 0) return 0.0f; |
| 67 | + |
| 68 | + uint8_t sign = (x >> 7) & 0x1u; |
| 69 | + uint8_t exp = (x >> 3) & 0xFu; |
| 70 | + uint8_t mant = x & 0x7u; |
| 71 | + |
| 72 | + // NaN (NVIDIA uses 0x7F / 0xFF as NaN) |
| 73 | + if (exp == 0xF && mant != 0) { |
| 74 | + return std::numeric_limits<float>::quiet_NaN(); |
| 75 | + } |
| 76 | + if (exp == 0xF) { // 0x7F or 0xFF treated as NaN |
| 77 | + return std::numeric_limits<float>::quiet_NaN(); |
| 78 | + } |
| 79 | + |
| 80 | + float val; |
| 81 | + |
| 82 | + if (exp == 0) { |
| 83 | + // Subnormal |
| 84 | + val = mant * (1.0f / 8.0f) * sycl::pow(2.0f, -6.0f); |
| 85 | + } else { |
| 86 | + // Normal: implicit leading 1 + bias 7 |
| 87 | + val = (1.0f + mant / 8.0f) * sycl::pow(2.0f, static_cast<float>(exp) - 7.0f); |
| 88 | + } |
| 89 | + |
| 90 | + return sign ? -val : val; |
| 91 | +} |
| 92 | + |
| 93 | +// The actual type definition |
| 94 | +struct __nv_fp8_e4m3 { |
| 95 | + uint8_t raw; |
| 96 | + |
| 97 | + __nv_fp8_e4m3() = default; |
| 98 | + |
| 99 | + explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {} |
| 100 | + explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast<float>(h))) {} |
| 101 | + |
| 102 | + operator float() const { return e4m3_to_float(raw); } |
| 103 | + operator sycl::half() const { return static_cast<sycl::half>(static_cast<float>(*this)); } |
| 104 | + |
| 105 | + // Allow direct access for vector loads/stores |
| 106 | + operator uint8_t&() { return raw; } |
| 107 | + operator uint8_t() const { return raw; } |
| 108 | +}; |
| 109 | + |
| 110 | +using __nv_fp8x2_e4m3 = sycl::vec<__nv_fp8_e4m3, 2>; |
| 111 | +using __nv_fp8x4_e4m3 = sycl::vec<__nv_fp8_e4m3, 4>; |
| 112 | + |
0 commit comments