Skip to content

Commit 7f07ebb

Browse files
LostBeardclaude
andcommitted
FP4 (Float4E2M1) phase 0a: managed core type + conversion, bit-exact to float4_e2m1fn (CPU-verified)
First increment of the 4-bit tier (Captain wants FP4/INT4/MXFP4/NF4 before SIMD). FP4 E2M1FN = the NVFP4/MXFP4 element format: 1 sign / 2 exp / 1 mantissa, bias 1, 16 finite codes {0,.5,1,1.5,2,3,4,6 ±}, NO Inf/NaN. New ILGPU.Float4E2M1 struct modeled on Float8E4M3 (FP32-based Math/Compare/Convert intrinsic operators + decode/encode bit-manip): RNE encode, finite overflow + ±Inf saturate to ±6, NaN -> 0x8 (-0) per the reference. STORAGE = 1 byte (value in low nibble), the FP8 pattern - NOT 4-bit packed. Decided after reading the IR size model: PrimitiveType.Size is byte-granular int (Int1 is sized 4 bytes!), so a 0.5-byte packed type would need a deep allocation/offset-layer change. The 4-bit memory saving belongs in the MXFP4/NF4 block-dequant layer (packed nibbles in raw buffers + scale, the GGUF Q4_K model), not the per-element core type. Plan corrected accordingly. Verified (CPU/managed, like FP8 phase 0a): DemoConsole -- bf16-f16-oracle CompareFP4 = decode all 16 codes 16/16 + encode probes 108/108 bit-exact to ml_dtypes.float4_e2m1fn. Oracle generator _research/fp8_oracle/gen_float4_oracle.py. WIP - the [*Intrinsic] attributes are managed-only until the IR primitive + 6-backend convert are wired (next increments); Float4E2M1 works on CPU now, GPU-kernel use awaits wiring. No version bump (not a shippable feature yet). NEXT: BasicValueType.Float4E2M1 IR primitive + GenericMath + per-backend 1-byte convert + radix + capability + cross-backend PMT, then INT4. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent efb0669 commit 7f07ebb

5 files changed

Lines changed: 489 additions & 33 deletions

File tree

ILGPU/Float4E2M1.cs

Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
// ---------------------------------------------------------------------------------------
2+
// ILGPU
3+
//
4+
// File: Float4E2M1.cs
5+
//
6+
// The kernel-native 4-bit floating-point type in the OCP "E2M1" layout (E2M1FN, the finite
7+
// ML variant; the element format of NVFP4 / MXFP4): 1 sign / 2 exponent / 1 mantissa bits,
8+
// exponent bias 1. ALL 16 codes are finite - NO infinities, NO NaN. The representable
9+
// magnitudes are exactly {0, 0.5, 1, 1.5, 2, 3, 4, 6}; max is 6 (0x7 / 0xF).
10+
//
11+
// Bit-exact to `ml_dtypes.float4_e2m1fn` (PyTorch/JAX share it), verified by
12+
// `DemoConsole -- fp4-oracle`:
13+
// code 0x0..0x7 = +{0,.5,1,1.5,2,3,4,6}; 0x8..0xF = the negatives (sign bit = bit 3).
14+
// encode is round-to-nearest-even among the 16 values; finite overflow AND +-Inf SATURATE
15+
// to +-6; NaN -> 0x8 (the format has no NaN encoding; ml_dtypes maps NaN -> -0, matched here).
16+
//
17+
// STORAGE = 1 byte (the 4-bit value in the low nibble), exactly like Float8E4M3/E5M2 - it reuses
18+
// the existing 1-byte sub-word machinery on every backend. (The IR type-size model is byte-
19+
// granular; true 4-bit nibble packing belongs in the MXFP4/NF4 block-dequant layer, not here.)
20+
//
21+
// Modeled on ILGPU.Float8E4M3: FP32-based [MathIntrinsic]/[CompareIntrinisc]/[ConvertIntrinisc]
22+
// operators (transpiled on every backend).
23+
// ---------------------------------------------------------------------------------------
24+
25+
using ILGPU.Frontend.Intrinsic;
26+
using ILGPU.IR.Values;
27+
using ILGPU.Util;
28+
using System;
29+
#if !DEBUG
30+
using System.Diagnostics;
31+
#endif
32+
using System.Runtime.CompilerServices;
33+
34+
namespace ILGPU
35+
{
36+
/// <summary>
37+
/// A 4-bit floating-point value in OCP E2M1 (E2M1FN) layout (1 sign, 2 exponent, 1 mantissa,
38+
/// bias 1). All 16 codes finite (NO Inf/NaN); magnitudes {0,.5,1,1.5,2,3,4,6}, max 6. The
39+
/// NVFP4/MXFP4 element format. 1-byte storage (value in the low nibble).
40+
/// </summary>
41+
[Serializable]
42+
public readonly partial struct Float4E2M1 :
43+
IEquatable<Float4E2M1>, IComparable<Float4E2M1>
44+
{
45+
#region Static
46+
47+
/// <summary>Returns the absolute value of the given E2M1 value.</summary>
48+
[MathIntrinsic(MathIntrinsicKind.Abs)]
49+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
50+
public static Float4E2M1 Abs(Float4E2M1 value) => Float4E2M1Extensions.Abs(value);
51+
52+
/// <summary>Returns true if the given E2M1 value represents 0 (either sign).</summary>
53+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
54+
public static bool IsZero(Float4E2M1 value) => Float4E2M1Extensions.IsZero(value);
55+
56+
/// <summary>Returns true always - E2M1 has no Inf or NaN (every code is finite).</summary>
57+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
58+
public static bool IsFinite(Float4E2M1 value) => true;
59+
60+
/// <summary>
61+
/// Converts a float to E2M1. Round-to-nearest-even among the 16 values; finite overflow and
62+
/// +-Inf saturate to +-6; NaN -> -0 (bit-exact to ml_dtypes float4_e2m1fn). Same as the cast.
63+
/// </summary>
64+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
65+
public static Float4E2M1 FromSingle(float value) =>
66+
Float4E2M1Extensions.ConvertFloatToFloat4E2M1(value);
67+
68+
#endregion
69+
70+
#region Constants
71+
72+
/// <summary>Positive zero (0x0).</summary>
73+
public static readonly Float4E2M1 Zero = new Float4E2M1(0x0);
74+
75+
/// <summary>The value one (exp=1, mant=0 -&gt; 0x2).</summary>
76+
public static readonly Float4E2M1 One = new Float4E2M1(0x2);
77+
78+
/// <summary>The smallest positive value (0.5 subnormal, 0x1).</summary>
79+
public static readonly Float4E2M1 Epsilon = new Float4E2M1(0x1);
80+
81+
/// <summary>The largest finite value (6.0, 0x7). E2M1 has no Inf.</summary>
82+
public static readonly Float4E2M1 MaxValue = new Float4E2M1(0x7);
83+
84+
/// <summary>The smallest finite value (-6.0, 0xF).</summary>
85+
public static readonly Float4E2M1 MinValue = new Float4E2M1(0xF);
86+
87+
#endregion
88+
89+
#region Instance
90+
91+
/// <summary>Constructs a new E2M1 value from its raw 4-bit pattern (low nibble of the byte).</summary>
92+
internal Float4E2M1(byte rawValue)
93+
{
94+
RawValue = (byte)(rawValue & 0x0F);
95+
}
96+
97+
#endregion
98+
99+
#region Properties
100+
101+
/// <summary>The raw 4-bit value (stored in the low nibble of a byte).</summary>
102+
#if !DEBUG
103+
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
104+
#endif
105+
internal byte RawValue { get; }
106+
107+
#endregion
108+
109+
#region IEquatable / IComparable / Object
110+
111+
/// <summary>Returns true if the given E2M1 is equal to the current value (by float value).</summary>
112+
public readonly bool Equals(Float4E2M1 other) => (float)this == other;
113+
114+
/// <summary>Compares this E2M1 value to the given one (by float value).</summary>
115+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
116+
public readonly int CompareTo(Float4E2M1 other) => ((float)this).CompareTo(other);
117+
118+
/// <summary>Returns true if the given object is equal to the current value.</summary>
119+
public readonly override bool Equals(object? obj) =>
120+
obj is Float4E2M1 value && Equals(value);
121+
122+
/// <summary>Returns the hash code of this value.</summary>
123+
public readonly override int GetHashCode() => RawValue;
124+
125+
/// <summary>Returns the string representation of this value.</summary>
126+
public readonly override string ToString() => ((float)this).ToString();
127+
128+
#endregion
129+
130+
#region Operators
131+
132+
/// <summary>Negates the given E2M1 value (flip the sign bit).</summary>
133+
[MathIntrinsic(MathIntrinsicKind.Neg)]
134+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
135+
public static Float4E2M1 operator -(Float4E2M1 value) => Float4E2M1Extensions.Neg(value);
136+
137+
/// <summary>Adds two E2M1 values.</summary>
138+
[MathIntrinsic(MathIntrinsicKind.Add)]
139+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
140+
public static Float4E2M1 operator +(Float4E2M1 first, Float4E2M1 second) =>
141+
(Float4E2M1)((float)first + second);
142+
143+
/// <summary>Subtracts two E2M1 values.</summary>
144+
[MathIntrinsic(MathIntrinsicKind.Sub)]
145+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
146+
public static Float4E2M1 operator -(Float4E2M1 first, Float4E2M1 second) =>
147+
(Float4E2M1)((float)first - second);
148+
149+
/// <summary>Multiplies two E2M1 values.</summary>
150+
[MathIntrinsic(MathIntrinsicKind.Mul)]
151+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
152+
public static Float4E2M1 operator *(Float4E2M1 first, Float4E2M1 second) =>
153+
(Float4E2M1)((float)first * second);
154+
155+
/// <summary>Divides two E2M1 values.</summary>
156+
[MathIntrinsic(MathIntrinsicKind.Div)]
157+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
158+
public static Float4E2M1 operator /(Float4E2M1 first, Float4E2M1 second) =>
159+
(Float4E2M1)((float)first / second);
160+
161+
/// <summary>Returns true if the two values are equal.</summary>
162+
[CompareIntrinisc(CompareKind.Equal)]
163+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
164+
public static bool operator ==(Float4E2M1 first, Float4E2M1 second) =>
165+
(float)first == second;
166+
167+
/// <summary>Returns true if the two values are not equal.</summary>
168+
[CompareIntrinisc(CompareKind.NotEqual, CompareFlags.UnsignedOrUnordered)]
169+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
170+
public static bool operator !=(Float4E2M1 first, Float4E2M1 second) =>
171+
(float)first != second;
172+
173+
/// <summary>Returns true if the first value is smaller than the second.</summary>
174+
[CompareIntrinisc(CompareKind.LessThan)]
175+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
176+
public static bool operator <(Float4E2M1 first, Float4E2M1 second) =>
177+
(float)first < second;
178+
179+
/// <summary>Returns true if the first value is smaller than or equal to the second.</summary>
180+
[CompareIntrinisc(CompareKind.LessEqual)]
181+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
182+
public static bool operator <=(Float4E2M1 first, Float4E2M1 second) =>
183+
(float)first <= second;
184+
185+
/// <summary>Returns true if the first value is greater than the second.</summary>
186+
[CompareIntrinisc(CompareKind.GreaterThan)]
187+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
188+
public static bool operator >(Float4E2M1 first, Float4E2M1 second) =>
189+
(float)first > second;
190+
191+
/// <summary>Returns true if the first value is greater than or equal to the second.</summary>
192+
[CompareIntrinisc(CompareKind.GreaterEqual)]
193+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
194+
public static bool operator >=(Float4E2M1 first, Float4E2M1 second) =>
195+
(float)first >= second;
196+
197+
/// <summary>Implicitly converts an E2M1 to a float.</summary>
198+
[ConvertIntrinisc]
199+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
200+
public static implicit operator float(Float4E2M1 value) =>
201+
Float4E2M1Extensions.ConvertFloat4E2M1ToFloat(value);
202+
203+
/// <summary>Implicitly converts an E2M1 to a double.</summary>
204+
[ConvertIntrinisc]
205+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
206+
public static implicit operator double(Float4E2M1 value) => (float)value;
207+
208+
/// <summary>Explicitly converts a float to an E2M1.</summary>
209+
[ConvertIntrinisc]
210+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
211+
public static explicit operator Float4E2M1(float value) =>
212+
Float4E2M1Extensions.ConvertFloatToFloat4E2M1(value);
213+
214+
/// <summary>Explicitly converts a double to an E2M1.</summary>
215+
[ConvertIntrinisc]
216+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
217+
public static explicit operator Float4E2M1(double value) =>
218+
(Float4E2M1)(float)value;
219+
220+
#endregion
221+
}
222+
223+
/// <summary>
224+
/// Extension/implementation methods for the <see cref="Float4E2M1"/> type.
225+
/// </summary>
226+
public static partial class Float4E2M1Extensions
227+
{
228+
#region Constants
229+
230+
private const byte SignBitMask = 0x8; // bit 3
231+
private const byte MagnitudeMask = 0x7; // exp(2) + mantissa(1)
232+
private const byte MaxFiniteMagnitude = 0x7; // 6.0
233+
234+
#endregion
235+
236+
#region Conversion
237+
238+
/// <summary>Converts an E2M1 value to a float (rebias 1 -&gt; 127; 1 mantissa bit; no Inf/NaN).</summary>
239+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
240+
public static float ConvertFloat4E2M1ToFloat(Float4E2M1 value)
241+
{
242+
uint code = value.RawValue;
243+
uint sign = (code & 0x8u) << 28; // f32 sign bit (bit 31)
244+
uint e = (code >> 1) & 0x3u; // exponent field (2 bits)
245+
uint m = code & 0x1u; // mantissa (1 bit)
246+
247+
if (e == 0u)
248+
{
249+
if (m == 0u)
250+
return Interop.IntAsFloat(sign); // +-0
251+
// Subnormal 0.5 = 2^-1: f32 exponent 126, mantissa 0.
252+
return Interop.IntAsFloat(sign | (126u << 23));
253+
}
254+
// Normal: value = 1.m * 2^(e-1). f32 exp = (e-1)+127; the single mantissa bit -> bit 22.
255+
uint f32Exp = e - 1u + 127u;
256+
return Interop.IntAsFloat(sign | (f32Exp << 23) | (m << 22));
257+
}
258+
259+
/// <summary>
260+
/// Converts a float to an E2M1 value using round-to-nearest-even. Finite overflow and +-Inf
261+
/// SATURATE to +-6; NaN -&gt; -0 (0x8). Bit-exact to ml_dtypes float4_e2m1fn.
262+
/// </summary>
263+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
264+
public static Float4E2M1 ConvertFloatToFloat4E2M1(float value)
265+
{
266+
uint bits = Interop.FloatAsInt(value);
267+
uint sign = (bits >> 28) & 0x8u; // E2M1 sign bit (bit 3)
268+
uint rest = bits & 0x7FFFFFFFu;
269+
270+
// NaN -> 0x8 (-0); the format has no NaN. (ml_dtypes convention.)
271+
if (rest > 0x7F800000u)
272+
return new Float4E2M1((byte)0x8);
273+
// +-Inf -> saturate to +-6.
274+
if (rest >= 0x7F800000u)
275+
return new Float4E2M1((byte)(sign | MaxFiniteMagnitude));
276+
277+
int f32Exp = (int)((rest >> 23) & 0xFFu);
278+
uint f32Mant = rest & 0x7FFFFFu;
279+
int e = f32Exp - 127; // unbiased
280+
281+
// E2M1 normal exponent range: 0..2 (bias 1); max finite 6 at e=2, mant=1.
282+
// Finite overflow (e>2, or e==2 with mantissa rounding past 1.5) -> saturate to +-6.
283+
if (e > 2)
284+
return new Float4E2M1((byte)(sign | MaxFiniteMagnitude));
285+
286+
if (e < 0)
287+
{
288+
// Subnormal (only 0.5 = 2^-1) or zero. signif = 1.f32Mant scaled to the 0.5 grid.
289+
if (f32Exp == 0)
290+
return new Float4E2M1((byte)sign); // f32 zero/subnormal -> +-0
291+
uint signif = f32Mant | 0x800000u; // implicit 1 (24-bit)
292+
// Smallest E2M1 step is 0.5 = 2^-1. value = signif * 2^(e-23); we want round(value / 2^-1)
293+
// = round(signif * 2^(e-23+1)) as the count of 0.5 units, clamped to {0,1} (mantissa bit).
294+
int shift = (-1 - e) + 23; // align signif to the 2^-1 unit
295+
if (shift > 31)
296+
return new Float4E2M1((byte)sign); // underflow -> +-0
297+
uint q = signif >> shift; // integer count of 0.5 units
298+
uint roundBit = (signif >> (shift - 1)) & 1u;
299+
uint sticky = (signif & ((1u << (shift - 1)) - 1u)) != 0u ? 1u : 0u;
300+
if (roundBit == 1u && (sticky == 1u || (q & 1u) == 1u))
301+
q += 1u; // RNE; q may carry 0->1 or 1->2(=1.0)
302+
// q==0 -> 0; q==1 -> 0.5 (0x1); q==2 -> 1.0 (0x2 = smallest normal).
303+
return new Float4E2M1((byte)(sign | (q & 0x7u)));
304+
}
305+
306+
// Normal range (e in 0..2). Rebias and round the mantissa 23 -> 1 bit (RNE).
307+
uint mant1 = f32Mant >> 22; // top mantissa bit
308+
uint round = (f32Mant >> 21) & 1u; // first dropped bit
309+
uint stick = (f32Mant & 0x1FFFFFu) != 0u ? 1u : 0u;
310+
uint eField = (uint)(e + 1); // bias 1
311+
uint outBits = (eField << 1) | mant1;
312+
if (round == 1u && (stick == 1u || (mant1 & 1u) == 1u))
313+
outBits += 1u; // ties-to-even; may carry into the exponent
314+
// A carry past the max magnitude (0x7 = 6) saturates to 6 (no larger finite, no Inf).
315+
if ((outBits & 0x7u) > MaxFiniteMagnitude || outBits > 0x7u)
316+
outBits = MaxFiniteMagnitude;
317+
return new Float4E2M1((byte)(sign | (outBits & 0x7u)));
318+
}
319+
320+
#endregion
321+
322+
#region Predicates
323+
324+
/// <summary>Negates the given E2M1 value (flip the sign bit).</summary>
325+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
326+
public static Float4E2M1 Neg(Float4E2M1 value) =>
327+
new Float4E2M1((byte)(value.RawValue ^ SignBitMask));
328+
329+
/// <summary>Returns the absolute value (clear the sign bit).</summary>
330+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
331+
public static Float4E2M1 Abs(Float4E2M1 value) =>
332+
new Float4E2M1((byte)(value.RawValue & MagnitudeMask));
333+
334+
/// <summary>Returns true if the value is +-0 (magnitude == 0).</summary>
335+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
336+
public static bool IsZero(Float4E2M1 value) =>
337+
(value.RawValue & MagnitudeMask) == 0;
338+
339+
#endregion
340+
341+
#region FP32 Implementation Methods
342+
343+
/// <summary>Implements an E2M1 addition using FP32.</summary>
344+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
345+
public static Float4E2M1 AddFP32(Float4E2M1 first, Float4E2M1 second) =>
346+
(Float4E2M1)((float)first + second);
347+
348+
/// <summary>Implements an E2M1 subtraction using FP32.</summary>
349+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
350+
public static Float4E2M1 SubFP32(Float4E2M1 first, Float4E2M1 second) =>
351+
(Float4E2M1)((float)first - second);
352+
353+
/// <summary>Implements an E2M1 multiplication using FP32.</summary>
354+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
355+
public static Float4E2M1 MulFP32(Float4E2M1 first, Float4E2M1 second) =>
356+
(Float4E2M1)((float)first * second);
357+
358+
/// <summary>Implements an E2M1 division using FP32.</summary>
359+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
360+
public static Float4E2M1 DivFP32(Float4E2M1 first, Float4E2M1 second) =>
361+
(Float4E2M1)((float)first / second);
362+
363+
#endregion
364+
}
365+
}

0 commit comments

Comments
 (0)