|
| 1 | +// --------------------------------------------------------------------------------------- |
| 2 | +// ILGPU |
| 3 | +// |
| 4 | +// File: Float4E2M1.cs |
| 5 | +// |
| 6 | +// The kernel-native 4-bit floating-point type in the OCP "E2M1" layout (E2M1FN, the finite |
| 7 | +// ML variant; the element format of NVFP4 / MXFP4): 1 sign / 2 exponent / 1 mantissa bits, |
| 8 | +// exponent bias 1. ALL 16 codes are finite - NO infinities, NO NaN. The representable |
| 9 | +// magnitudes are exactly {0, 0.5, 1, 1.5, 2, 3, 4, 6}; max is 6 (0x7 / 0xF). |
| 10 | +// |
| 11 | +// Bit-exact to `ml_dtypes.float4_e2m1fn` (PyTorch/JAX share it), verified by |
| 12 | +// `DemoConsole -- fp4-oracle`: |
| 13 | +// code 0x0..0x7 = +{0,.5,1,1.5,2,3,4,6}; 0x8..0xF = the negatives (sign bit = bit 3). |
| 14 | +// encode is round-to-nearest-even among the 16 values; finite overflow AND +-Inf SATURATE |
| 15 | +// to +-6; NaN -> 0x8 (the format has no NaN encoding; ml_dtypes maps NaN -> -0, matched here). |
| 16 | +// |
| 17 | +// STORAGE = 1 byte (the 4-bit value in the low nibble), exactly like Float8E4M3/E5M2 - it reuses |
| 18 | +// the existing 1-byte sub-word machinery on every backend. (The IR type-size model is byte- |
| 19 | +// granular; true 4-bit nibble packing belongs in the MXFP4/NF4 block-dequant layer, not here.) |
| 20 | +// |
| 21 | +// Modeled on ILGPU.Float8E4M3: FP32-based [MathIntrinsic]/[CompareIntrinisc]/[ConvertIntrinisc] |
| 22 | +// operators (transpiled on every backend). |
| 23 | +// --------------------------------------------------------------------------------------- |
| 24 | + |
| 25 | +using ILGPU.Frontend.Intrinsic; |
| 26 | +using ILGPU.IR.Values; |
| 27 | +using ILGPU.Util; |
| 28 | +using System; |
| 29 | +#if !DEBUG |
| 30 | +using System.Diagnostics; |
| 31 | +#endif |
| 32 | +using System.Runtime.CompilerServices; |
| 33 | + |
| 34 | +namespace ILGPU |
| 35 | +{ |
/// <summary>
/// A 4-bit floating-point value in OCP E2M1 (E2M1FN) layout: 1 sign bit, 2 exponent bits,
/// 1 mantissa bit, exponent bias 1. All 16 codes are finite (no Inf, no NaN); the magnitudes
/// are {0, 0.5, 1, 1.5, 2, 3, 4, 6}. The value is stored in the low nibble of a single byte.
/// </summary>
/// <remarks>
/// Equality, ordering and hashing are defined on the decoded float value, so the two zero
/// codes (0x0 and 0x8) are equal to each other and share one hash code.
/// </remarks>
[Serializable]
public readonly partial struct Float4E2M1 :
    IEquatable<Float4E2M1>, IComparable<Float4E2M1>
{
    #region Static

    /// <summary>Returns the absolute value of the given E2M1 value.</summary>
    /// <param name="value">The value whose sign bit is cleared.</param>
    /// <returns>The same magnitude with a cleared sign bit.</returns>
    [MathIntrinsic(MathIntrinsicKind.Abs)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 Abs(Float4E2M1 value) => Float4E2M1Extensions.Abs(value);

    /// <summary>Returns true if the given E2M1 value represents 0 (either sign).</summary>
    /// <param name="value">The value to test.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool IsZero(Float4E2M1 value) => Float4E2M1Extensions.IsZero(value);

    /// <summary>Returns true always - E2M1 has no Inf or NaN (every code is finite).</summary>
    /// <param name="value">The value to test (not inspected).</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool IsFinite(Float4E2M1 value) => true;

    /// <summary>
    /// Converts a float to E2M1. Round-to-nearest-even among the 16 values; finite overflow
    /// and +-Inf saturate to +-6; NaN -> -0. Same as the explicit cast from float.
    /// </summary>
    /// <param name="value">The float to encode.</param>
    /// <returns>The encoded E2M1 value.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 FromSingle(float value) =>
        Float4E2M1Extensions.ConvertFloatToFloat4E2M1(value);

    #endregion

    #region Constants

    /// <summary>Positive zero (0x0).</summary>
    public static readonly Float4E2M1 Zero = new Float4E2M1(0x0);

    /// <summary>The value one (exp=1, mant=0 -> 0x2).</summary>
    public static readonly Float4E2M1 One = new Float4E2M1(0x2);

    /// <summary>The smallest positive value (0.5 subnormal, 0x1).</summary>
    public static readonly Float4E2M1 Epsilon = new Float4E2M1(0x1);

    /// <summary>The largest finite value (6.0, 0x7). E2M1 has no Inf.</summary>
    public static readonly Float4E2M1 MaxValue = new Float4E2M1(0x7);

    /// <summary>The smallest finite value (-6.0, 0xF).</summary>
    public static readonly Float4E2M1 MinValue = new Float4E2M1(0xF);

    #endregion

    #region Instance

    /// <summary>
    /// Constructs a new E2M1 value from its raw 4-bit pattern. Only the low nibble of
    /// <paramref name="rawValue"/> is kept; the high nibble is discarded.
    /// </summary>
    /// <param name="rawValue">The raw code (sign in bit 3, magnitude in bits 0..2).</param>
    internal Float4E2M1(byte rawValue)
    {
        RawValue = (byte)(rawValue & 0x0F);
    }

    #endregion

    #region Properties

    /// <summary>The raw 4-bit value (stored in the low nibble of a byte).</summary>
#if !DEBUG
    [DebuggerBrowsable(DebuggerBrowsableState.Never)]
#endif
    internal byte RawValue { get; }

    #endregion

    #region IEquatable / IComparable / Object

    /// <summary>
    /// Returns true if the given E2M1 is equal to the current value. The comparison is made
    /// on the decoded float values, so +0 and -0 are equal.
    /// </summary>
    /// <param name="other">The value to compare with.</param>
    public readonly bool Equals(Float4E2M1 other) => (float)this == other;

    /// <summary>Compares this E2M1 value to the given one (by float value).</summary>
    /// <param name="other">The value to compare with.</param>
    /// <returns>The result of comparing the two decoded float values.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public readonly int CompareTo(Float4E2M1 other) => ((float)this).CompareTo(other);

    /// <summary>Returns true if the given object is equal to the current value.</summary>
    /// <param name="obj">The object to compare with.</param>
    public readonly override bool Equals(object? obj) =>
        obj is Float4E2M1 value && Equals(value);

    /// <summary>
    /// Returns the hash code of this value. Both zero codes (0x0 and 0x8) hash to 0 because
    /// <see cref="Equals(Float4E2M1)"/> treats them as equal; every other code decodes to a
    /// distinct float, so the raw code itself is used.
    /// </summary>
    public readonly override int GetHashCode() =>
        Float4E2M1Extensions.IsZero(this) ? 0 : RawValue;

    /// <summary>Returns the string representation of the decoded float value.</summary>
    public readonly override string ToString() => ((float)this).ToString();

    #endregion

    #region Operators

    /// <summary>Negates the given E2M1 value (flip the sign bit).</summary>
    [MathIntrinsic(MathIntrinsicKind.Neg)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 operator -(Float4E2M1 value) => Float4E2M1Extensions.Neg(value);

    /// <summary>Adds two E2M1 values in float and re-encodes the sum.</summary>
    [MathIntrinsic(MathIntrinsicKind.Add)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 operator +(Float4E2M1 first, Float4E2M1 second) =>
        (Float4E2M1)((float)first + second);

    /// <summary>Subtracts two E2M1 values in float and re-encodes the difference.</summary>
    [MathIntrinsic(MathIntrinsicKind.Sub)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 operator -(Float4E2M1 first, Float4E2M1 second) =>
        (Float4E2M1)((float)first - second);

    /// <summary>Multiplies two E2M1 values in float and re-encodes the product.</summary>
    [MathIntrinsic(MathIntrinsicKind.Mul)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 operator *(Float4E2M1 first, Float4E2M1 second) =>
        (Float4E2M1)((float)first * second);

    /// <summary>Divides two E2M1 values in float and re-encodes the quotient.</summary>
    [MathIntrinsic(MathIntrinsicKind.Div)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 operator /(Float4E2M1 first, Float4E2M1 second) =>
        (Float4E2M1)((float)first / second);

    /// <summary>Returns true if the two values are equal (by float value).</summary>
    [CompareIntrinisc(CompareKind.Equal)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool operator ==(Float4E2M1 first, Float4E2M1 second) =>
        (float)first == second;

    /// <summary>Returns true if the two values are not equal (by float value).</summary>
    [CompareIntrinisc(CompareKind.NotEqual, CompareFlags.UnsignedOrUnordered)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool operator !=(Float4E2M1 first, Float4E2M1 second) =>
        (float)first != second;

    /// <summary>Returns true if the first value is smaller than the second.</summary>
    [CompareIntrinisc(CompareKind.LessThan)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool operator <(Float4E2M1 first, Float4E2M1 second) =>
        (float)first < second;

    /// <summary>Returns true if the first value is smaller than or equal to the second.</summary>
    [CompareIntrinisc(CompareKind.LessEqual)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool operator <=(Float4E2M1 first, Float4E2M1 second) =>
        (float)first <= second;

    /// <summary>Returns true if the first value is greater than the second.</summary>
    [CompareIntrinisc(CompareKind.GreaterThan)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool operator >(Float4E2M1 first, Float4E2M1 second) =>
        (float)first > second;

    /// <summary>Returns true if the first value is greater than or equal to the second.</summary>
    [CompareIntrinisc(CompareKind.GreaterEqual)]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool operator >=(Float4E2M1 first, Float4E2M1 second) =>
        (float)first >= second;

    /// <summary>Implicitly converts an E2M1 to a float.</summary>
    [ConvertIntrinisc]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static implicit operator float(Float4E2M1 value) =>
        Float4E2M1Extensions.ConvertFloat4E2M1ToFloat(value);

    /// <summary>Implicitly converts an E2M1 to a double (via its float value).</summary>
    [ConvertIntrinisc]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static implicit operator double(Float4E2M1 value) => (float)value;

    /// <summary>Explicitly converts a float to an E2M1.</summary>
    [ConvertIntrinisc]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static explicit operator Float4E2M1(float value) =>
        Float4E2M1Extensions.ConvertFloatToFloat4E2M1(value);

    /// <summary>
    /// Explicitly converts a double to an E2M1. The value is first narrowed to float and
    /// then encoded, so it is rounded twice rather than once.
    /// </summary>
    [ConvertIntrinisc]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static explicit operator Float4E2M1(double value) =>
        (Float4E2M1)(float)value;

    #endregion
}
| 222 | + |
/// <summary>
/// Extension/implementation methods for the <see cref="Float4E2M1"/> type: the bit-level
/// float conversions, the sign helpers, and FP32-based arithmetic.
/// </summary>
public static partial class Float4E2M1Extensions
{
    #region Constants

    private const byte SignBitMask = 0x8;          // bit 3
    private const byte MagnitudeMask = 0x7;        // exp(2) + mantissa(1)
    private const byte MaxFiniteMagnitude = 0x7;   // 6.0

    #endregion

    #region Conversion

    /// <summary>
    /// Converts an E2M1 value to a float (rebias 1 -> 127; 1 mantissa bit; no Inf/NaN).
    /// </summary>
    /// <param name="value">The E2M1 value to decode.</param>
    /// <returns>The float with exactly the same numeric value (the sign of zero is kept).</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static float ConvertFloat4E2M1ToFloat(Float4E2M1 value)
    {
        uint code = value.RawValue;
        uint sign = (code & SignBitMask) << 28;    // E2M1 bit 3 -> f32 sign bit (bit 31)
        uint e = (code >> 1) & 0x3u;               // exponent field (2 bits)
        uint m = code & 0x1u;                      // mantissa (1 bit)

        if (e == 0u)
        {
            if (m == 0u)
                return Interop.IntAsFloat(sign);   // +-0
            // Subnormal 0.5 = 2^-1: f32 exponent 126, mantissa 0.
            return Interop.IntAsFloat(sign | (126u << 23));
        }

        // Normal: value = 1.m * 2^(e-1). f32 exp = (e-1)+127; the single mantissa bit
        // lands in f32 mantissa bit 22.
        uint f32Exp = e - 1u + 127u;
        return Interop.IntAsFloat(sign | (f32Exp << 23) | (m << 22));
    }

    /// <summary>
    /// Converts a float to an E2M1 value using round-to-nearest-even. Finite overflow and
    /// +-Inf saturate to +-6; NaN -> -0 (0x8).
    /// </summary>
    /// <param name="value">The float to encode.</param>
    /// <returns>The nearest E2M1 value under the rules above.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 ConvertFloatToFloat4E2M1(float value)
    {
        uint bits = Interop.FloatAsInt(value);
        uint sign = (bits >> 28) & SignBitMask;    // f32 bit 31 -> E2M1 sign bit (bit 3)
        uint rest = bits & 0x7FFFFFFFu;            // magnitude bits of the float

        // NaN -> 0x8 (-0); the format has no NaN encoding.
        if (rest > 0x7F800000u)
            return new Float4E2M1(SignBitMask);
        // +-Inf -> saturate to +-6.
        if (rest >= 0x7F800000u)
            return new Float4E2M1((byte)(sign | MaxFiniteMagnitude));

        int f32Exp = (int)((rest >> 23) & 0xFFu);
        uint f32Mant = rest & 0x7FFFFFu;
        int e = f32Exp - 127;                      // unbiased exponent

        // E2M1 normals cover unbiased exponents 0..2; anything larger saturates to +-6.
        if (e > 2)
            return new Float4E2M1((byte)(sign | MaxFiniteMagnitude));

        if (e < 0)
        {
            // Magnitude below 1.0: the result is 0, 0.5 or (after rounding up) 1.0.
            if (f32Exp == 0)
                return new Float4E2M1((byte)sign); // f32 zero/subnormal -> +-0
            uint signif = f32Mant | 0x800000u;     // implicit 1 (24-bit significand)
            // value = signif * 2^(e-23). Counting in units of 0.5 = 2^-1 means shifting
            // the significand right by (-1 - e) + 23 bits.
            int shift = (-1 - e) + 23;
            // C# masks uint shift counts to 5 bits, so counts of 32 or more must not
            // reach the shifts below; such inputs are far below 0.25 and become +-0.
            if (shift > 31)
                return new Float4E2M1((byte)sign);
            uint q = signif >> shift;              // integer count of 0.5 units
            uint roundBit = (signif >> (shift - 1)) & 1u;
            uint sticky = (signif & ((1u << (shift - 1)) - 1u)) != 0u ? 1u : 0u;
            if (roundBit == 1u && (sticky == 1u || (q & 1u) == 1u))
                q += 1u;                           // RNE; q may carry 0->1 or 1->2
            // q is 0, 1 or 2 and equals the magnitude code directly:
            // 0 -> 0 (0x0), 1 -> 0.5 (0x1), 2 -> 1.0 (0x2, the smallest normal).
            return new Float4E2M1((byte)(sign | (q & MagnitudeMask)));
        }

        // Normal range (e in 0..2). Rebias and round the mantissa 23 -> 1 bit (RNE).
        uint mant1 = f32Mant >> 22;                // top mantissa bit
        uint round = (f32Mant >> 21) & 1u;         // first dropped bit
        uint stick = (f32Mant & 0x1FFFFFu) != 0u ? 1u : 0u;
        uint eField = (uint)(e + 1);               // bias 1
        uint outBits = (eField << 1) | mant1;
        if (round == 1u && (stick == 1u || (mant1 & 1u) == 1u))
            outBits += 1u;                         // ties-to-even; may carry into the exponent
        // Rounding up from 0x7 yields 0x8, which is past the largest magnitude: saturate
        // to 6 (there is no larger finite value and no Inf).
        if (outBits > MaxFiniteMagnitude)
            outBits = MaxFiniteMagnitude;
        return new Float4E2M1((byte)(sign | outBits));
    }

    #endregion

    #region Predicates

    /// <summary>Negates the given E2M1 value (flip the sign bit).</summary>
    /// <param name="value">The value to negate.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 Neg(Float4E2M1 value) =>
        new Float4E2M1((byte)(value.RawValue ^ SignBitMask));

    /// <summary>Returns the absolute value (clear the sign bit).</summary>
    /// <param name="value">The value whose sign bit is cleared.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 Abs(Float4E2M1 value) =>
        new Float4E2M1((byte)(value.RawValue & MagnitudeMask));

    /// <summary>Returns true if the value is +-0 (magnitude == 0).</summary>
    /// <param name="value">The value to test.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool IsZero(Float4E2M1 value) =>
        (value.RawValue & MagnitudeMask) == 0;

    #endregion

    #region FP32 Implementation Methods

    /// <summary>Implements an E2M1 addition using FP32.</summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 AddFP32(Float4E2M1 first, Float4E2M1 second) =>
        (Float4E2M1)((float)first + second);

    /// <summary>Implements an E2M1 subtraction using FP32.</summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 SubFP32(Float4E2M1 first, Float4E2M1 second) =>
        (Float4E2M1)((float)first - second);

    /// <summary>Implements an E2M1 multiplication using FP32.</summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 MulFP32(Float4E2M1 first, Float4E2M1 second) =>
        (Float4E2M1)((float)first * second);

    /// <summary>Implements an E2M1 division using FP32.</summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static Float4E2M1 DivFP32(Float4E2M1 first, Float4E2M1 second) =>
        (Float4E2M1)((float)first / second);

    #endregion
}
| 365 | +} |
0 commit comments