|
10 | 10 | // training recipe (E4M3 forward, E5M2 backward): it trades dynamic range for an extra mantissa |
11 | 11 | // bit vs E5M2, which is what forward activations/weights want. |
12 | 12 | // |
13 | | -// CONVENTION (flagged for ML-oracle confirmation, plan §9 risk #2 - confirm vs PyTorch |
14 | | -// float8_e4m3fn / NVIDIA Transformer Engine when wired into the ML lane): finite overflow |
15 | | -// SATURATES to +-448; a real +-Inf input maps to NaN (E4M3 has no Inf); NaN -> NaN. This |
16 | | -// matches the OCP/TE saturating-forward convention. Only the out-of-range INPUT behavior is |
17 | | -// convention-dependent; every REPRESENTABLE value round-trips exactly (verified by the CPU |
18 | | -// idempotence harness, `DemoConsole -- fp8-verify`). |
| 13 | +// OVERFLOW CONVENTION (verified vs the ml_dtypes reference, `DemoConsole -- fp8-oracle` - |
| 14 | +// ml_dtypes is the impl PyTorch / JAX float8_e4m3fn share). E4M3 has two real-world conventions |
| 15 | +// and BOTH are selectable here; the conversion is otherwise bit-exact to the reference (0 of
| 16 | +// 256 decode divergences; 0 encode rounding/subnormal divergences across 1099 probes):
| 17 | +// * SATURATING (the bare cast operator + FromSingleSaturating): finite overflow clamps to |
| 18 | +// +-448; +-Inf -> NaN; NaN -> NaN. Matches the NVIDIA Transformer Engine default cast / |
| 19 | +// OCP saturating-forward mode. Avoids NaN propagation when activations overflow unscaled. |
| 20 | +// * fn / non-saturating (FromSingleFn): finite overflow AND +-Inf -> NaN; NaN -> NaN. Bit- |
| 21 | +// exact to PyTorch/JAX/ml_dtypes float8_e4m3fn (the dtype this layout is named after). Use |
| 22 | +// this for reference-matching ML. The two conventions agree everywhere except |x|>464 (the |
| 23 | +// region that rounds up past the 448 slot): saturating gives +-448, fn gives NaN. |
| 24 | +// Every REPRESENTABLE value round-trips exactly under both (CPU idempotence harness fp8-verify). |
19 | 25 | // |
20 | 26 | // Modeled on ILGPU.Half / BFloat16 / Float8E5M2: FP32-based [MathIntrinsic]/[CompareIntrinisc]/ |
21 | 27 | // [ConvertIntrinisc] operators (transpiled on every backend). 1-byte storage. |
@@ -61,6 +67,35 @@ namespace ILGPU |
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool IsFinite(Float8E4M3 value)
{
    // Thin forwarder: the actual predicate is implemented in Float8E4M3Extensions.
    return Float8E4M3Extensions.IsFinite(value);
}
63 | 69 |
|
/// <summary>
/// Converts a float to E4M3 with a selectable overflow convention.
/// </summary>
/// <param name="value">The single-precision value to convert.</param>
/// <param name="saturate">
/// <c>true</c> (the default, matching the cast operator) selects the SATURATING convention:
/// finite overflow clamps to +-448, +-Inf -> NaN, NaN -> NaN (NVIDIA Transformer Engine / OCP
/// saturating cast). <c>false</c> selects the fn (non-saturating) convention: finite overflow
/// and +-Inf map to NaN, bit-exact to PyTorch/JAX/ml_dtypes float8_e4m3fn.
/// </param>
/// <returns>The E4M3 value produced under the selected convention.</returns>
/// <remarks>
/// The two conventions differ only for |value| > 464; everywhere else they yield the same bits.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Float8E4M3 FromSingle(float value, bool saturate = true) =>
    saturate
        ? Float8E4M3Extensions.ConvertFloatToFloat8E4M3(value)
        : Float8E4M3Extensions.FromSingleFn(value);
| 80 | + |
/// <summary>
/// Saturating float -> E4M3 conversion; same result as the explicit cast operator.
/// Finite overflow clamps to +-448, +-Inf becomes NaN, and NaN stays NaN. This is the
/// NVIDIA Transformer Engine default cast / OCP saturating-forward behavior.
/// </summary>
/// <param name="value">The single-precision value to convert.</param>
/// <returns>The saturating E4M3 encoding of <paramref name="value"/>.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Float8E4M3 FromSingleSaturating(float value)
{
    // The base converter already implements the saturating convention.
    return Float8E4M3Extensions.ConvertFloatToFloat8E4M3(value);
}
| 89 | + |
/// <summary>
/// Non-saturating (fn) float -> E4M3 conversion: finite overflow and +-Inf both become NaN,
/// and NaN stays NaN. Bit-exact to PyTorch/JAX/ml_dtypes float8_e4m3fn, so prefer it when
/// results must match those references. Only |value| > 464 differs from the saturating cast.
/// </summary>
/// <param name="value">The single-precision value to convert.</param>
/// <returns>The fn-convention E4M3 encoding of <paramref name="value"/>.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Float8E4M3 FromSingleFn(float value)
{
    // Forward to the shared implementation in the extensions class.
    return Float8E4M3Extensions.FromSingleFn(value);
}
| 98 | + |
64 | 99 | #endregion |
65 | 100 |
|
66 | 101 | #region Constants |
@@ -324,6 +359,27 @@ public static Float8E4M3 ConvertFloatToFloat8E4M3(float value) |
324 | 359 | return new Float8E4M3((byte)(sign | (outBits & 0x7Fu))); |
325 | 360 | } |
326 | 361 |
|
/// <summary>
/// float -> E4M3 conversion under the fn (float8_e4m3fn) overflow rule: finite overflow and
/// +-Inf produce NaN instead of saturating; NaN -> NaN. Bit-exact to PyTorch / JAX /
/// ml_dtypes (verified, <c>DemoConsole -- fp8-oracle</c>). Built purely from existing
/// intrinsics (compare, the saturating cast, Neg, cast-of-NaN), so it transpiles on every
/// backend without per-backend conversion codegen.
/// </summary>
/// <param name="value">The single-precision value to convert.</param>
/// <returns>The fn-convention E4M3 encoding of <paramref name="value"/>.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Float8E4M3 FromSingleFn(float value)
{
    // Magnitudes up to and including 464 still round to <= 448 in the saturating base
    // convert (bit-exact to the reference); anything larger rounds up past the 448 slot.
    const float OverflowThreshold = 464.0f;

    // Ordered compares are false for NaN, so a NaN input skips both guards and reaches the
    // base convert, which already maps NaN -> NaN. +-Inf trips the matching guard.
    if (value < -OverflowThreshold)
    {
        return -(Float8E4M3)float.NaN; // -overflow / -Inf -> -NaN (0xFF)
    }

    if (value > OverflowThreshold)
    {
        return (Float8E4M3)float.NaN; // +overflow / +Inf -> +NaN (0x7F)
    }

    return (Float8E4M3)value;
}
| 382 | + |
327 | 383 | #endregion |
328 | 384 |
|
329 | 385 | #region Predicates |
|
0 commit comments