Commit db52337

LostBeard and claude committed
FP4 (Float4E2M1) radix-sort keys on all 6 backends + latent PTX struct-field IO fix - 4.14.0-local.6
Closes the FP4 radix follow-up from local.5. Float4E2M1 arrays radix-sort (keys-only + pairs, ascending + descending) on CPU/CUDA/OpenCL/WebGPU/WebGL/Wasm, mirroring the FP8 radix work (e815bb7). Four-package bundle: forks 2.0.32 -> 2.0.33, SpawnDev.ILGPU local.5 -> local.6.

- Ascending/DescendingFloat4E2M1 (RadixSortOperations.Float4E2M1.cs): the Half/bf16/FP8 sign-flip + ones-complement key transform adapted to E2M1 (sign at BIT 3, magnitude in the low 3 bits, masks hardcoded to the nibble). FP4 is byte-stored (value in low nibble), so it sorts as a native 1-byte key (NumBits=8; the key is 0..15 so it stays monotonic over the byte) on 5 backends; WebGL uses the unpacked-f32 working representation (RadixSortExtensions.Float4E2M1.cs + dispatch arms in RadixSortExtensions.cs), like Half/bf16/FP8 - the whole-texel scatter can't move a sub-word value.
- Per-backend FloatAsInt(Float4E2M1) radix codegen completed: PTX + Wasm FloatAsIntCast/IntAsFloatCast for FP4 (the convert release wired OpenCL/WGSL/GLSL; PTX/Wasm here) recover the 4-bit pattern via the portable EmitF32ToFP4Bits / EmitFP4ToF32 helpers.
- LATENT BUG FIXED (Rule 2a): PTX EmitIOLoad/EmitIOStore (the struct-field path the RadixSortPairs kernel uses to bundle the 1-byte key with the value) handled bf16 + FP8 but NOT FP4, so an FP4 key field was stored as the f32 register's raw low byte (= 0 for most values) -> CUDA FP4 pairs returned ALL-ZERO keys. Added FP4 to both (EmitF32ToFP4Bits on store, EmitFP4BitsToF32 on load). Root-caused with a desktop repro printing the actual CUDA sorted output (keys-only worked; only the key-bundle pairs path was wrong - same class of bug bf16 needed its EmitIO fix for).

Gates: PMT Fp4Radix (ExtractBits GPU-vs-CPU + KeysDescending + PairsAscending) 23/0 ALL 6 backends; no regression (Fp8Radix 44/0, Float4E2M1 convert 23/0).

NEXT: INT4, then MXFP4/NF4 dequant (ML lane).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 48c9dc0 commit db52337

11 files changed

Lines changed: 565 additions & 5 deletions

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -2,6 +2,15 @@

 This file tracks notable changes per release. The README's "Recent Highlights" section links here for the full version history.

+## 4.14.0-local.6 (2026-06-17) - `Float4E2M1` radix-sort keys on all 6 backends + a latent PTX struct-field IO bug fix
+
+Closes the FP4 follow-up from local.5: `Float4E2M1` arrays can now be radix-sorted (keys-only + pairs, ascending + descending) on CPU/CUDA/OpenCL/WebGPU/WebGL/Wasm. Forks bump to `2.0.33`.
+
+- **`Ascending/DescendingFloat4E2M1` radix operations.** The same sign-flip + ones-complement key transform Half/bf16/FP8 use, adapted to the 4-bit E2M1 layout (sign at **bit 3**, not the top bit; magnitude in the low 3 bits). FP4 is stored as a 1-byte element (value in the low nibble), so it sorts as a native 1-byte key (NumBits=8, the key is 0..15 so it stays monotonic) on 5 backends; WebGL uses the unpacked-f32 working representation (the whole-texel scatter can't move a 1-byte sub-word), like Half/bf16/FP8.
+- **Per-backend `FloatAsInt(Float4E2M1)` radix codegen** completed: the PTX + Wasm `FloatAsIntCast`/`IntAsFloatCast` for FP4 (the convert release wired OpenCL/WGSL/GLSL; PTX/Wasm got it here) recover the 4-bit pattern via the portable bit-manip helpers.
+- **Fixed a real latent PTX struct-field IO bug (Rule 2a).** The PTX `EmitIOLoad`/`EmitIOStore` (the path the `RadixSortPairs` kernel uses to bundle the 1-byte key with the value) handled bf16 + FP8 but not FP4, so an FP4 key field was stored as the f32 register's raw low byte (= 0 for most values) → CUDA FP4 pairs returned **all-zero keys**. Added FP4 to both (round f32 → the 4-bit pattern via `EmitF32ToFP4Bits` on store, widen via `EmitFP4BitsToF32` on load). Root-caused with a desktop repro printing the actual CUDA sorted output (keys-only worked; only the key-bundle pairs path was wrong).
+- Gates: PMT `Fp4Radix` (ExtractBits GPU-vs-CPU + KeysDescending + PairsAscending) **23/0 all 6 backends**; no regression (`Fp8Radix` 44/0, `Float4E2M1` convert 23/0).
+
 ## 4.14.0-local.5 (2026-06-17) - New 4-bit float type `Float4E2M1` (NVFP4/MXFP4 element format) on all 6 backends + a latent low-precision store-widening bug fix

 Adds `ILGPU.Float4E2M1`, the OCP **E2M1FN** 4-bit float (the element format of NVFP4 / MXFP4): 1 sign / 2 exp / 1 mantissa, bias 1, **16 finite codes (no Inf, no NaN)**, magnitudes `{0,.5,1,1.5,2,3,4,6}`, max 6, finite overflow + ±Inf saturate to ±6, NaN→-0. 1-byte storage (value in the low nibble), f32-register compute. Forks bump to `2.0.32`. Bit-exact to `ml_dtypes.float4_e2m1fn` (PyTorch/JAX share it).
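
To make the E2M1 layout and the radix key transform from the changelog concrete, here is a minimal C# sketch. It is not the code from `RadixSortOperations.Float4E2M1.cs` (that file is not shown in this diff): `Fp4Sketch`, `Decode` and `AscendingKey` are hypothetical names, and the key mapping is the usual float-to-sortable-integer trick applied to a nibble, assumed to match what the changelog calls the "sign-flip + ones-complement" transform.

```csharp
using System;

// Hypothetical helper, not part of ILGPU: illustrates the E2M1 nibble layout
// (1 sign / 2 exponent / 1 mantissa, bias 1) described in the changelog.
public static class Fp4Sketch
{
    // Decodes a 4-bit E2M1 code (low nibble of a byte) to float.
    public static float Decode(byte code)
    {
        int bits = code & 0xF;
        int sign = (bits >> 3) & 1;   // sign lives at bit 3
        int exp = (bits >> 1) & 3;    // 2 exponent bits
        int man = bits & 1;           // 1 mantissa bit

        // exp == 0 is the subnormal row (0 or 0.5); otherwise 1.m * 2^(exp - bias).
        float magnitude = exp == 0
            ? man * 0.5f
            : (1.0f + man * 0.5f) * (1 << (exp - 1));
        return sign == 1 ? -magnitude : magnitude;
    }

    // Maps a code to an unsigned key whose integer order matches float order:
    // negative codes get all 4 bits complemented, non-negative codes get the
    // sign bit set so they sort above every negative code.
    public static int AscendingKey(byte code)
    {
        int bits = code & 0xF;
        return (bits & 0x8) != 0 ? (~bits & 0xF) : (bits | 0x8);
    }
}
```

Under this mapping the codes for -6, -0, +0 and +6 (`0xF`, `0x8`, `0x0`, `0x7`) receive the keys 0, 7, 8 and 15, so an unsigned radix sort over the keys orders the values ascending.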

ILGPU.Algorithms/ILGPU.Algorithms.csproj

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
          SpawnDev.ILGPU.Fork* PackageReference Versions inside SpawnDev.ILGPU.csproj.
          Run `_check-fork-version-sync.bat` at repo root. See the banner comment in
          SpawnDev.ILGPU.csproj for the full procedure. -->
-    <Version>2.0.32</Version>
+    <Version>2.0.33</Version>
     <IsPackable>true</IsPackable>
     <GeneratePackageOnBuild>true</GeneratePackageOnBuild>
   </PropertyGroup>
Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
+// ---------------------------------------------------------------------------------------
+//                                   ILGPU Algorithms
+//                        Copyright (c) 2019-2023 ILGPU Project
+//                                    www.ilgpu.net
+//
+// File: RadixSortExtensions.Float4E2M1.cs
+//
+// This file is part of ILGPU and is distributed under the University of Illinois Open
+// Source License. See LICENSE.txt for details.
+// ---------------------------------------------------------------------------------------
+
+using ILGPU.Algorithms.RadixSortOperations;
+using ILGPU.Algorithms.ScanReduceOperations;
+using ILGPU.Runtime;
+using System.Runtime.CompilerServices;
+
+namespace ILGPU.Algorithms
+{
+    // WebGL Float4E2M1-key radix sort. FP4 is a 1-byte sub-word key (value in the low
+    // nibble); the WebGL render-to-texture scatter writes WHOLE 32-bit texels and cannot
+    // move a sub-texel value, so - exactly like Half, BFloat16 and FP8 (RadixSortExtensions.cs
+    // / RadixSortExtensions.BFloat16.cs / RadixSortExtensions.Float8E4M3.cs) - FP4 sorts via
+    // an UNPACKED f32 working representation: copy-in widens each Float4E2M1 to f32 (lossless:
+    // each of the 16 finite FP4 codes is exactly representable in f32), the radix bit is derived
+    // by narrowing back to Float4E2M1 and calling the canonical ExtractRadixBits, and copy-out
+    // narrows the sorted f32 back to Float4E2M1 (exact round-trip for any value that began as a
+    // Float4E2M1). Mirrors the FP8 path one-for-one; the per-bit loop runs NumBits passes as
+    // reported by the radix operation, and only the low nibble carries FP4 key bits.
+    static partial class RadixSortExtensions
+    {
+        private static void WebGLScatterRadixCopyInFloat4E2M1<TStride>(
+            Index1D index, ArrayView1D<Float4E2M1, TStride> input,
+            ArrayView1D<float, Stride1D.Dense> output)
+            where TStride : struct, IStride1D =>
+            output[index.X] = (float)input[index.X];
+
+        private static void WebGLScatterRadixCopyOutFloat4E2M1<TStride>(
+            Index1D index, ArrayView1D<float, Stride1D.Dense> input,
+            ArrayView1D<Float4E2M1, TStride> output)
+            where TStride : struct, IStride1D =>
+            output[index.X] = (Float4E2M1)input[index.X];
+
+        private static void WebGLScatterRadixExtractBitFloat4E2M1<TRadixSortOperation>(
+            Index1D index, ArrayView1D<float, Stride1D.Dense> keys,
+            ArrayView1D<int, Stride1D.Dense> flags, int bit)
+            where TRadixSortOperation : struct, IRadixSortOperation<Float4E2M1>
+        {
+            TRadixSortOperation op = default;
+            flags[index.X] = op.ExtractRadixBits((Float4E2M1)keys[index.X], bit, 1);
+        }
+
+        // Keys-only Float4E2M1 sort. Invoked by reflection from CreateRadixSort (the outer
+        // method is generic on T; the compiler can't see T == Float4E2M1 to bind the
+        // IRadixSortOperation<Float4E2M1> constraint statically). Called once per handler.
+        private static RadixSort<Float4E2M1, TStride> CreateWebGLScatterRadixSortFloat4E2M1<
+            TStride, TRadixSortOperation>(Accelerator accelerator, IScatterProvider scatter)
+            where TStride : struct, IStride1D
+            where TRadixSortOperation : struct, IRadixSortOperation<Float4E2M1>
+        {
+            var copyIn = accelerator.LoadAutoGroupedKernel<
+                Index1D, ArrayView1D<Float4E2M1, TStride>, ArrayView1D<float, Stride1D.Dense>>(
+                WebGLScatterRadixCopyInFloat4E2M1<TStride>);
+            var copyOut = accelerator.LoadAutoGroupedKernel<
+                Index1D, ArrayView1D<float, Stride1D.Dense>, ArrayView1D<Float4E2M1, TStride>>(
+                WebGLScatterRadixCopyOutFloat4E2M1<TStride>);
+            var extractBit = accelerator.LoadAutoGroupedKernel<
+                Index1D, ArrayView1D<float, Stride1D.Dense>, ArrayView1D<int, Stride1D.Dense>, int>(
+                WebGLScatterRadixExtractBitFloat4E2M1<TRadixSortOperation>);
+            var computeDest = accelerator.LoadAutoGroupedKernel<
+                Index1D, ArrayView1D<int, Stride1D.Dense>, ArrayView1D<int, Stride1D.Dense>,
+                ArrayView1D<int, Stride1D.Dense>, int>(WebGLScatterRadixComputeDest);
+            var exclusiveScan = accelerator.CreateScan<
+                int, Stride1D.Dense, Stride1D.Dense, AddInt32>(ScanKind.Exclusive);
+
+            int numBits = default(TRadixSortOperation).NumBits; // pass count, from the operation
+
+            return (stream, view, temp) =>
+            {
+                int n = (int)view.Length;
+                if (n <= 1)
+                    return;
+
+                using var keysA = accelerator.Allocate1D<float>(n);
+                using var keysB = accelerator.Allocate1D<float>(n);
+                using var flags = accelerator.Allocate1D<int>(n);
+                using var onePrefix = accelerator.Allocate1D<int>(n);
+                using var dest = accelerator.Allocate1D<int>(n);
+                using var scanTemp = accelerator.Allocate1D<int>(1);
+
+                copyIn(stream, n, view, keysA.View);
+
+                var src = keysA;
+                var dst = keysB;
+                for (int bit = 0; bit < numBits; bit++)
+                {
+                    extractBit(stream, n, src.View, flags.View, bit);
+                    exclusiveScan(stream, flags.View, onePrefix.View, scanTemp.View);
+                    computeDest(stream, n, flags.View, onePrefix.View, dest.View, n);
+                    scatter.Scatter(dst.View, src.View, dest.View, n, "float");
+                    var tmp = src; src = dst; dst = tmp;
+                }
+
+                copyOut(stream, n, src.View, view);
+            };
+        }
+
+
+        // Float4E2M1-KEY pairs sort (FP4 key + any 4/8-byte non-FP4 value). Keys use the
+        // unpacked f32 working representation; values use the same int/float/uint scatter
+        // program as the generic pairs path. Invoked by reflection from CreateRadixSortPairs.
+        private static RadixSortPairs<Float4E2M1, TKeyStride, TValue, TValueStride>
+            CreateWebGLScatterRadixSortPairsFloat4E2M1Key<
+            TKeyStride, TValue, TValueStride, TRadixSortOperation>(
+            Accelerator accelerator, IScatterProvider scatter)
+            where TKeyStride : struct, IStride1D
+            where TValue : unmanaged
+            where TValueStride : struct, IStride1D
+            where TRadixSortOperation : struct, IRadixSortOperation<Float4E2M1>
+        {
+            var copyInKeys = accelerator.LoadAutoGroupedKernel<
+                Index1D, ArrayView1D<Float4E2M1, TKeyStride>, ArrayView1D<float, Stride1D.Dense>>(
+                WebGLScatterRadixCopyInFloat4E2M1<TKeyStride>);
+            var copyOutKeys = accelerator.LoadAutoGroupedKernel<
+                Index1D, ArrayView1D<float, Stride1D.Dense>, ArrayView1D<Float4E2M1, TKeyStride>>(
+                WebGLScatterRadixCopyOutFloat4E2M1<TKeyStride>);
+            var copyInVals = accelerator.LoadAutoGroupedKernel<
+                Index1D, ArrayView1D<TValue, TValueStride>, ArrayView1D<TValue, Stride1D.Dense>>(
+                WebGLScatterRadixCopyIn<TValue, TValueStride>);
+            var copyOutVals = accelerator.LoadAutoGroupedKernel<
+                Index1D, ArrayView1D<TValue, Stride1D.Dense>, ArrayView1D<TValue, TValueStride>>(
+                WebGLScatterRadixCopyOut<TValue, TValueStride>);
+            var extractBit = accelerator.LoadAutoGroupedKernel<
+                Index1D, ArrayView1D<float, Stride1D.Dense>, ArrayView1D<int, Stride1D.Dense>, int>(
+                WebGLScatterRadixExtractBitFloat4E2M1<TRadixSortOperation>);
+            var computeDest = accelerator.LoadAutoGroupedKernel<
+                Index1D, ArrayView1D<int, Stride1D.Dense>, ArrayView1D<int, Stride1D.Dense>,
+                ArrayView1D<int, Stride1D.Dense>, int>(WebGLScatterRadixComputeDest);
+            var exclusiveScan = accelerator.CreateScan<
+                int, Stride1D.Dense, Stride1D.Dense, AddInt32>(ScanKind.Exclusive);
+
+            int numBits = default(TRadixSortOperation).NumBits; // pass count, from the operation
+            string valType = WebGLScatterValueType<TValue>();
+            int valCpe = WebGLScatterCpe<TValue>();
+
+            return (stream, keys, values, tempView) =>
+            {
+                int n = (int)keys.Length;
+                if (n <= 1)
+                    return;
+
+                using var keysA = accelerator.Allocate1D<float>(n);
+                using var keysB = accelerator.Allocate1D<float>(n);
+                using var valsA = accelerator.Allocate1D<TValue>(n);
+                using var valsB = accelerator.Allocate1D<TValue>(n);
+                using var flags = accelerator.Allocate1D<int>(n);
+                using var onePrefix = accelerator.Allocate1D<int>(n);
+                using var dest = accelerator.Allocate1D<int>(n);
+                using var scanTemp = accelerator.Allocate1D<int>(1);
+
+                copyInKeys(stream, n, keys, keysA.View);
+                copyInVals(stream, n, values, valsA.View);
+
+                var kSrc = keysA;
+                var kDst = keysB;
+                var vSrc = valsA;
+                var vDst = valsB;
+                for (int bit = 0; bit < numBits; bit++)
+                {
+                    extractBit(stream, n, kSrc.View, flags.View, bit);
+                    exclusiveScan(stream, flags.View, onePrefix.View, scanTemp.View);
+                    computeDest(stream, n, flags.View, onePrefix.View, dest.View, n);
+                    scatter.Scatter(kDst.View, kSrc.View, dest.View, n, "float", 1);
+                    scatter.Scatter(vDst.View, vSrc.View, dest.View, n, valType, valCpe);
+                    var kt = kSrc; kSrc = kDst; kDst = kt;
+                    var vt = vSrc; vSrc = vDst; vDst = vt;
+                }
+
+                copyOutKeys(stream, n, kSrc.View, keys);
+                copyOutVals(stream, n, vSrc.View, values);
+            };
+        }
+    }
+}
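
The per-bit loop in this file (extract bit, exclusive scan, compute destination, scatter) is a one-bit "split" radix pass. A CPU sketch of that structure follows. `SplitPassSketch` is a hypothetical name, and the destination formula is the standard split formulation; it is an assumption about what `WebGLScatterRadixComputeDest` computes, since that kernel's body is not part of this diff.

```csharp
using System;

// Hypothetical CPU reference for a stable 1-bit radix pass (not ILGPU code).
public static class SplitPassSketch
{
    // Returns keys reordered so that elements whose `bit` is 0 come first,
    // preserving relative order inside each group (stability).
    public static int[] SplitPass(int[] keys, int bit)
    {
        int n = keys.Length;
        var flags = new int[n];
        var onePrefix = new int[n];
        var result = new int[n];

        // Step 1 (extractBit): flag is 1 when the selected key bit is set.
        for (int i = 0; i < n; i++)
            flags[i] = (keys[i] >> bit) & 1;

        // Step 2 (exclusiveScan): onePrefix[i] = number of ones before index i.
        int running = 0;
        for (int i = 0; i < n; i++)
        {
            onePrefix[i] = running;
            running += flags[i];
        }
        int totalZeros = n - running;

        // Steps 3 and 4 (computeDest, scatter): zeros keep their rank among the
        // zeros, ones are placed after all zeros in their original order.
        for (int i = 0; i < n; i++)
        {
            int dest = flags[i] == 1
                ? totalZeros + onePrefix[i]
                : i - onePrefix[i];
            result[dest] = keys[i];
        }
        return result;
    }

    // Full LSD sort: one stable pass per key bit, lowest bit first.
    public static int[] Sort(int[] keys, int numBits)
    {
        var current = (int[])keys.Clone();
        for (int bit = 0; bit < numBits; bit++)
            current = SplitPass(current, bit);
        return current;
    }
}
```

Because every pass is stable, a pass over a bit that is zero in every key maps each element to its own index. Sorting keys in the range 0..15 therefore gives the same result with `numBits` set to 4 or to 8; the extra passes only cost time.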

ILGPU.Algorithms/RadixSortExtensions.cs

Lines changed: 29 additions & 0 deletions
@@ -1210,6 +1210,23 @@ accelerator is IScatterProvider scatterProviderE5M2Key &&
                 return (RadixSortPairs<TKey, TKeyStride, TValue, TValueStride>)handler;
             }

+            // Float4E2M1 KEY (+ any 4/8-byte non-FP4 value): same sub-word unpacked-f32 path.
+            if (accelerator.AcceleratorType == AcceleratorType.WebGL &&
+                accelerator is IScatterProvider scatterProviderE2M1Key &&
+                typeof(TKey) == typeof(Float4E2M1) &&
+                (Interop.SizeOf<TValue>() == 4 || Interop.SizeOf<TValue>() == 8) &&
+                typeof(TValue) != typeof(Float4E2M1))
+            {
+                var handler = typeof(RadixSortExtensions)
+                    .GetMethod(nameof(CreateWebGLScatterRadixSortPairsFloat4E2M1Key),
+                        BindingFlags.NonPublic | BindingFlags.Static)!
+                    .MakeGenericMethod(
+                        typeof(TKeyStride), typeof(TValue), typeof(TValueStride),
+                        typeof(TRadixSortOperation))
+                    .Invoke(null, new object[] { accelerator, scatterProviderE2M1Key })!;
+                return (RadixSortPairs<TKey, TKeyStride, TValue, TValueStride>)handler;
+            }
+
             if (accelerator.AcceleratorType == AcceleratorType.WebGL &&
                 accelerator is IScatterProvider scatterProviderPairs &&
                 (Interop.SizeOf<TKey>() == 4 || Interop.SizeOf<TKey>() == 8) &&
@@ -1564,6 +1581,18 @@ private static RadixSort<T, TStride> CreateWebGLScatterRadixSortDispatch<
                     .Invoke(null, new object[] { accelerator, scatterProvider })!;
                 return (RadixSort<T, TStride>)handler;
             }
+            // FP4 (E2M1FN) is a 1-byte sub-word key (value in the low nibble) - same
+            // unpacked-f32 working representation as Half/bf16/FP8 (each of the 16
+            // finite FP4 codes is exactly representable in f32).
+            if (typeof(T) == typeof(Float4E2M1))
+            {
+                var handler = typeof(RadixSortExtensions)
+                    .GetMethod(nameof(CreateWebGLScatterRadixSortFloat4E2M1),
+                        BindingFlags.NonPublic | BindingFlags.Static)!
+                    .MakeGenericMethod(typeof(TStride), typeof(TRadixSortOperation))
+                    .Invoke(null, new object[] { accelerator, scatterProvider })!;
+                return (RadixSort<T, TStride>)handler;
+            }
             return CreateWebGLScatterRadixSort<T, TStride, TRadixSortOperation>(
                 accelerator, scatterProvider);
         }
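
Both dispatch arms above use the same pattern: a method that is generic on `T` checks `typeof(T)` at run time and then binds a more tightly constrained generic method through reflection, because the compiler cannot prove the constraint from inside the outer method. A small self-contained sketch of that pattern, using only `System.Reflection`; `IKeyOp`, `ByteOp` and `Dispatch` are hypothetical stand-ins, not ILGPU types.

```csharp
using System;
using System.Reflection;

// Hypothetical stand-in for IRadixSortOperation<T>.
public interface IKeyOp<T> { int NumBits { get; } }

public struct ByteOp : IKeyOp<byte> { public int NumBits => 8; }

public static class Dispatch
{
    // Needs TOp : IKeyOp<byte>, which the outer method below cannot state
    // because it is generic on an arbitrary T.
    private static int PassesForByte<TOp>() where TOp : struct, IKeyOp<byte>
        => default(TOp).NumBits;

    public static int Passes<T, TOp>() where TOp : struct
    {
        if (typeof(T) == typeof(byte))
        {
            // MakeGenericMethod validates the constraint at run time and
            // throws ArgumentException if TOp does not satisfy it.
            return (int)typeof(Dispatch)
                .GetMethod(nameof(PassesForByte),
                    BindingFlags.NonPublic | BindingFlags.Static)!
                .MakeGenericMethod(typeof(TOp))
                .Invoke(null, null)!;
        }
        throw new NotSupportedException($"No handler for {typeof(T)}");
    }
}
```

The reflection cost is paid once when the handler is created, not per sort call, which matches the "Called once per handler" note in the new file.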
