Skip to content

Commit e815bb7

Browse files
LostBeardclaude
andcommitted
FP8 (Float8E4M3 / Float8E5M2) radix-sort keys on all 6 backends (4.13.1-local.1)
Closes the tracked 4.13.0 follow-up: FP8 arrays can now be radix-sorted (keys-only + pairs, ascending + descending) on CPU/CUDA/OpenCL/WebGPU/WebGL/Wasm. - Interop.FloatAsInt(Float8E4M3)/(Float8E5M2) (raw 8-bit pattern, like the Half/BFloat16 twins) + FloatAsIntCast IR lowering for FP8 (constant-fold + Int8 result sizing; per-backend codegen PTX/OpenCL/WGSL/GLSL/Wasm). - Ascending/DescendingFloat8E4M3/E5M2 radix operations (sign-flip + ones-complement float key transform at 8-bit width; both formats are magnitude-monotonic). WebGL uses the unpacked-f32 path (like Half/bf16); the other 5 backends sort native 1-byte keys. - WebGPU packed-sub-word fix: FP8 is its own BasicValueType (not Int8), so it was skipped by every Int8/Int16/BFloat16 WGSL classification switch and fell to f32 -> the key buffer was declared array<f32> and read via a raw whole-word deref, corrupting the sort (WebGPU only). Added FP8 to all four sub-word switches (body-struct binding-type + body-struct LEA + direct-param LEA + coalesce) so it is declared packed array<atomic<u32>> and extracted+converted at load/store, mirroring bf16's 2-per-word path. Localized via the Dawn dump_shaders Tint dump (PMT_DAWN_DUMP=1). Gate: new BackendTestBase.Fp8Radix_E{4M3,5M2}_{ExtractBits,KeysDescending, PairsAscending}; PMT_FILTER=Fp8Radix 36/0 across all 6 backends. No regression to bf16/Half radix. Four-package bundle: forks 2.0.26 -> 2.0.27, SpawnDev.ILGPU 4.13.0 -> 4.13.1-local.1. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent bb2c8d5 commit e815bb7

20 files changed

Lines changed: 1058 additions & 9 deletions

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22

33
This file tracks notable changes per release. The README's "Recent Highlights" section links here for the full version history.
44

5+
## 4.13.1 (unreleased) - FP8 radix-sort keys on all 6 backends
6+
7+
### local.1 - FP8 (Float8E4M3 / Float8E5M2) radix-sort keys
8+
9+
- **FP8 arrays can now be radix-sorted on all 6 backends** (keys-only + key/value pairs, ascending + descending) - closing the tracked 4.13.0 follow-up. Added: `Interop.FloatAsInt(Float8E4M3)` / `(Float8E5M2)` (the raw 8-bit pattern, like the `Half`/`BFloat16` twins); the IR `FloatAsIntCast` lowering for FP8 across all backends (constant-fold + `Int8` result sizing in `IR/Construction/Cast.cs`; per-backend codegen on PTX `EmitF32ToFP8Bits`, OpenCL `_f32_to_e4m3_bits`, WGSL/GLSL `_f32_to_e4m3`, Wasm `EmitF32ToFP8`); and `Ascending`/`DescendingFloat8E4M3`/`E5M2` radix operations (the sign-flip + ones-complement float key transform at 8-bit width - both E4M3 and E5M2 are magnitude-monotonic, exponent above mantissa). On WebGL FP8 keys sort via the unpacked-f32 working representation (same as Half/bf16, since the whole-texel scatter can't move a sub-word value); on the other 5 backends as native 1-byte keys.
10+
- **WebGPU packed-sub-word fix (the hard part).** `Float8E4M3`/`Float8E5M2` are their OWN `BasicValueType` (NOT `Int8`), so they were silently skipped by every `case Int8/Int16/BFloat16` switch in the WGSL codegen and fell to a default that maps FP8 -> `f32`. For a packed FP8 key buffer this meant: the binding was declared `array<f32>` instead of `array<atomic<u32>>`, and the kernel read each key via a raw whole-word deref instead of a 4-per-word byte extract + `_e4m3_to_f32` - so the radix sort read garbage and corrupted the result (WebGPU only; the 5 other backends were correct). Fixed by adding FP8 to all four WGSL sub-word classification switches (body-struct binding-type, body-struct LEA, direct-param LEA, direct-param coalesce) so FP8 is declared packed `array<atomic<u32>>` and extracted+converted at load/store - exactly the path bf16 (2-per-word) already used. Localized with the Dawn `dump_shaders` Tint-output dump (`PMT_DAWN_DUMP=1`), not by staring at the WGSL.
11+
- Gate: new `BackendTestBase.Fp8Radix_E{4M3,5M2}_{ExtractBits,KeysDescending,PairsAscending}` - GPU-vs-CPU radix-bucket compare + keys-only descending (tiled to a multi-group size) + key/value pairs ascending, over distinct exactly-representable FP8 values spanning negative..positive. **`PMT_FILTER=Fp8Radix` 36/0 across all 6 backends** (CPU + CUDA + OpenCL + WebGPU + WebGPU-NoSubgroups + WebGL + Wasm). No regression to bf16/Half radix.
12+
513
## 4.13.0 (2026-06-16) - Low-precision floats on all 6 backends: BFloat16 + FP8 (Float8E4M3 / Float8E5M2), generic INumber<T> mixed-precision kernels, PrecisionConvert, and bf16/FP8 portability to pre-Ampere CUDA cards
614

715
> 4.13.0 was developed across the local.5 -> local.10 series; the dated headline above is the stable cut. Per-milestone detail follows.

ILGPU.Algorithms/ILGPU.Algorithms.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
SpawnDev.ILGPU.Fork* PackageReference Versions inside SpawnDev.ILGPU.csproj.
1313
Run `_check-fork-version-sync.bat` at repo root. See the banner comment in
1414
SpawnDev.ILGPU.csproj for the full procedure. -->
15-
<Version>2.0.26</Version>
15+
<Version>2.0.27</Version>
1616
<IsPackable>true</IsPackable>
1717
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
1818
</PropertyGroup>
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
// ---------------------------------------------------------------------------------------
2+
// ILGPU Algorithms
3+
// Copyright (c) 2019-2023 ILGPU Project
4+
// www.ilgpu.net
5+
//
6+
// File: RadixSortExtensions.Float8E4M3.cs
7+
//
8+
// This file is part of ILGPU and is distributed under the University of Illinois Open
9+
// Source License. See LICENSE.txt for details.
10+
// ---------------------------------------------------------------------------------------
11+
12+
using ILGPU.Algorithms.RadixSortOperations;
13+
using ILGPU.Algorithms.ScanReduceOperations;
14+
using ILGPU.Runtime;
15+
using System.Runtime.CompilerServices;
16+
17+
namespace ILGPU.Algorithms
18+
{
19+
// WebGL Float8E4M3-key radix sort. FP8 is a 1-byte sub-word key; the WebGL
20+
// render-to-texture scatter writes WHOLE 32-bit texels and cannot move a sub-texel
21+
// value, so - exactly like Half and BFloat16 (RadixSortExtensions.cs /
22+
// RadixSortExtensions.BFloat16.cs) - FP8 sorts via an UNPACKED f32 working
23+
// representation: copy-in widens each Float8E4M3 to f32 (lossless: every FP8 value is a
24+
// strict subset of f32), the radix bit is derived by narrowing back to Float8E4M3 and
25+
// calling the canonical ExtractRadixBits, and copy-out narrows the sorted f32 back to
26+
// Float8E4M3 (exact round-trip for any value that began as a Float8E4M3). Mirrors the
27+
// BFloat16 path one-for-one.
28+
static partial class RadixSortExtensions
29+
{
30+
private static void WebGLScatterRadixCopyInFloat8E4M3<TStride>(
31+
Index1D index, ArrayView1D<Float8E4M3, TStride> input,
32+
ArrayView1D<float, Stride1D.Dense> output)
33+
where TStride : struct, IStride1D =>
34+
output[index.X] = (float)input[index.X];
35+
36+
private static void WebGLScatterRadixCopyOutFloat8E4M3<TStride>(
37+
Index1D index, ArrayView1D<float, Stride1D.Dense> input,
38+
ArrayView1D<Float8E4M3, TStride> output)
39+
where TStride : struct, IStride1D =>
40+
output[index.X] = (Float8E4M3)input[index.X];
41+
42+
private static void WebGLScatterRadixExtractBitFloat8E4M3<TRadixSortOperation>(
43+
Index1D index, ArrayView1D<float, Stride1D.Dense> keys,
44+
ArrayView1D<int, Stride1D.Dense> flags, int bit)
45+
where TRadixSortOperation : struct, IRadixSortOperation<Float8E4M3>
46+
{
47+
TRadixSortOperation op = default;
48+
flags[index.X] = op.ExtractRadixBits((Float8E4M3)keys[index.X], bit, 1);
49+
}
50+
51+
// Keys-only Float8E4M3 sort. Invoked by reflection from CreateRadixSort (the outer
52+
// method is generic on T; the compiler can't see T == Float8E4M3 to bind the
53+
// IRadixSortOperation<Float8E4M3> constraint statically). Called once per handler.
54+
private static RadixSort<Float8E4M3, TStride> CreateWebGLScatterRadixSortFloat8E4M3<
55+
TStride, TRadixSortOperation>(Accelerator accelerator, IScatterProvider scatter)
56+
where TStride : struct, IStride1D
57+
where TRadixSortOperation : struct, IRadixSortOperation<Float8E4M3>
58+
{
59+
var copyIn = accelerator.LoadAutoGroupedKernel<
60+
Index1D, ArrayView1D<Float8E4M3, TStride>, ArrayView1D<float, Stride1D.Dense>>(
61+
WebGLScatterRadixCopyInFloat8E4M3<TStride>);
62+
var copyOut = accelerator.LoadAutoGroupedKernel<
63+
Index1D, ArrayView1D<float, Stride1D.Dense>, ArrayView1D<Float8E4M3, TStride>>(
64+
WebGLScatterRadixCopyOutFloat8E4M3<TStride>);
65+
var extractBit = accelerator.LoadAutoGroupedKernel<
66+
Index1D, ArrayView1D<float, Stride1D.Dense>, ArrayView1D<int, Stride1D.Dense>, int>(
67+
WebGLScatterRadixExtractBitFloat8E4M3<TRadixSortOperation>);
68+
var computeDest = accelerator.LoadAutoGroupedKernel<
69+
Index1D, ArrayView1D<int, Stride1D.Dense>, ArrayView1D<int, Stride1D.Dense>,
70+
ArrayView1D<int, Stride1D.Dense>, int>(WebGLScatterRadixComputeDest);
71+
var exclusiveScan = accelerator.CreateScan<
72+
int, Stride1D.Dense, Stride1D.Dense, AddInt32>(ScanKind.Exclusive);
73+
74+
int numBits = default(TRadixSortOperation).NumBits; // 8
75+
76+
return (stream, view, temp) =>
77+
{
78+
int n = (int)view.Length;
79+
if (n <= 1)
80+
return;
81+
82+
using var keysA = accelerator.Allocate1D<float>(n);
83+
using var keysB = accelerator.Allocate1D<float>(n);
84+
using var flags = accelerator.Allocate1D<int>(n);
85+
using var onePrefix = accelerator.Allocate1D<int>(n);
86+
using var dest = accelerator.Allocate1D<int>(n);
87+
using var scanTemp = accelerator.Allocate1D<int>(1);
88+
89+
copyIn(stream, n, view, keysA.View);
90+
91+
var src = keysA;
92+
var dst = keysB;
93+
for (int bit = 0; bit < numBits; bit++)
94+
{
95+
extractBit(stream, n, src.View, flags.View, bit);
96+
exclusiveScan(stream, flags.View, onePrefix.View, scanTemp.View);
97+
computeDest(stream, n, flags.View, onePrefix.View, dest.View, n);
98+
scatter.Scatter(dst.View, src.View, dest.View, n, "float");
99+
var tmp = src; src = dst; dst = tmp;
100+
}
101+
102+
copyOut(stream, n, src.View, view);
103+
};
104+
}
105+
106+
107+
// Float8E4M3-KEY pairs sort (FP8 key + any 4/8-byte non-FP8 value). Keys use the
108+
// unpacked f32 working representation; values use the same int/float/uint scatter
109+
// program as the generic pairs path. Invoked by reflection from CreateRadixSortPairs.
110+
private static RadixSortPairs<Float8E4M3, TKeyStride, TValue, TValueStride>
111+
CreateWebGLScatterRadixSortPairsFloat8E4M3Key<
112+
TKeyStride, TValue, TValueStride, TRadixSortOperation>(
113+
Accelerator accelerator, IScatterProvider scatter)
114+
where TKeyStride : struct, IStride1D
115+
where TValue : unmanaged
116+
where TValueStride : struct, IStride1D
117+
where TRadixSortOperation : struct, IRadixSortOperation<Float8E4M3>
118+
{
119+
var copyInKeys = accelerator.LoadAutoGroupedKernel<
120+
Index1D, ArrayView1D<Float8E4M3, TKeyStride>, ArrayView1D<float, Stride1D.Dense>>(
121+
WebGLScatterRadixCopyInFloat8E4M3<TKeyStride>);
122+
var copyOutKeys = accelerator.LoadAutoGroupedKernel<
123+
Index1D, ArrayView1D<float, Stride1D.Dense>, ArrayView1D<Float8E4M3, TKeyStride>>(
124+
WebGLScatterRadixCopyOutFloat8E4M3<TKeyStride>);
125+
var copyInVals = accelerator.LoadAutoGroupedKernel<
126+
Index1D, ArrayView1D<TValue, TValueStride>, ArrayView1D<TValue, Stride1D.Dense>>(
127+
WebGLScatterRadixCopyIn<TValue, TValueStride>);
128+
var copyOutVals = accelerator.LoadAutoGroupedKernel<
129+
Index1D, ArrayView1D<TValue, Stride1D.Dense>, ArrayView1D<TValue, TValueStride>>(
130+
WebGLScatterRadixCopyOut<TValue, TValueStride>);
131+
var extractBit = accelerator.LoadAutoGroupedKernel<
132+
Index1D, ArrayView1D<float, Stride1D.Dense>, ArrayView1D<int, Stride1D.Dense>, int>(
133+
WebGLScatterRadixExtractBitFloat8E4M3<TRadixSortOperation>);
134+
var computeDest = accelerator.LoadAutoGroupedKernel<
135+
Index1D, ArrayView1D<int, Stride1D.Dense>, ArrayView1D<int, Stride1D.Dense>,
136+
ArrayView1D<int, Stride1D.Dense>, int>(WebGLScatterRadixComputeDest);
137+
var exclusiveScan = accelerator.CreateScan<
138+
int, Stride1D.Dense, Stride1D.Dense, AddInt32>(ScanKind.Exclusive);
139+
140+
int numBits = default(TRadixSortOperation).NumBits; // 8
141+
string valType = WebGLScatterValueType<TValue>();
142+
int valCpe = WebGLScatterCpe<TValue>();
143+
144+
return (stream, keys, values, tempView) =>
145+
{
146+
int n = (int)keys.Length;
147+
if (n <= 1)
148+
return;
149+
150+
using var keysA = accelerator.Allocate1D<float>(n);
151+
using var keysB = accelerator.Allocate1D<float>(n);
152+
using var valsA = accelerator.Allocate1D<TValue>(n);
153+
using var valsB = accelerator.Allocate1D<TValue>(n);
154+
using var flags = accelerator.Allocate1D<int>(n);
155+
using var onePrefix = accelerator.Allocate1D<int>(n);
156+
using var dest = accelerator.Allocate1D<int>(n);
157+
using var scanTemp = accelerator.Allocate1D<int>(1);
158+
159+
copyInKeys(stream, n, keys, keysA.View);
160+
copyInVals(stream, n, values, valsA.View);
161+
162+
var kSrc = keysA;
163+
var kDst = keysB;
164+
var vSrc = valsA;
165+
var vDst = valsB;
166+
for (int bit = 0; bit < numBits; bit++)
167+
{
168+
extractBit(stream, n, kSrc.View, flags.View, bit);
169+
exclusiveScan(stream, flags.View, onePrefix.View, scanTemp.View);
170+
computeDest(stream, n, flags.View, onePrefix.View, dest.View, n);
171+
scatter.Scatter(kDst.View, kSrc.View, dest.View, n, "float", 1);
172+
scatter.Scatter(vDst.View, vSrc.View, dest.View, n, valType, valCpe);
173+
var kt = kSrc; kSrc = kDst; kDst = kt;
174+
var vt = vSrc; vSrc = vDst; vDst = vt;
175+
}
176+
177+
copyOutKeys(stream, n, kSrc.View, keys);
178+
copyOutVals(stream, n, vSrc.View, values);
179+
};
180+
}
181+
}
182+
}

0 commit comments

Comments
 (0)