Skip to content

Commit 1f29845

Browse files
LostBeard and claude committed
FP4 PTX/CUDA backend: portable bit-manip convert - FP4 now on all 3 desktop backends (+ latent FP8/bf16 store fix)
Wires Float4E2M1 through the CUDA/PTX codegen, mirroring the FP8 PTX path (1dd4d5b). FP4 uses the f32-register model (value held as f32, rounded to the 1-byte E2M1 grid only at the store); all conversion is PORTABLE integer bit-manip (no native cvt) so it runs on EVERY CUDA arch incl. the 1080 (sm_61) / 2060 (sm_75) - the [[feedback-native-cvt-shortcuts-gate-out-older-hardware]] rule. - EmitFP4BitsToF32 / EmitF32ToFP4Bits: portable E2M1<->f32 (branchless setp/selp, RNE; finite overflow + +-Inf saturate to +-6, NaN->-0). Direct ports of the managed Float4E2M1 conversion (CPU-verified bit-exact to ml_dtypes.float4_e2m1fn). Decode is branch-light (E2M1's only subnormal is 0.5, no normalize loop); encode reuses the FP8 generic shape with E2M1 constants. - ConvertValue: FP4<->f32 (and FP4<->FP4) is a register no-op (f32-register model). - Load: ArrayView<FP4> -> ld.u8 -> EmitFP4BitsToF32. Store (target FP4): EmitF32ToFP4Bits -> st.u8. - FP4 scalar param: .b8 declaration + ld.param.u8 + widen via EmitFP4BitsToF32. - PTXCodeGenerator.Emitter: FP4 ConstantRegister emits its f32 magnitude (fixes the relu compare). - PTXInstructions: Float4E2M1 added to all 5 bf16/FP8->f32 remap sites (Select/Compare/Unary/ Binary/Ternary arithmetic). LATENT BUG FIXED (Rule 2a): the "low-precision value stored to a wider buffer" guard covered only bf16 - so `floatBuf[i] = (float)fp8Buf[i]` / `(float)fp4Buf[i]` re-narrowed the f32 register to a st.b8 into the 4-byte slot, reading back ~0 (caught here via the FP4 decode kernel + a PTX dump: `st.b8 [y+i*4], %f1`). Generalized the guard to bf16 + FP8E4M3 + FP8E5M2 + FP4E2M1: store the f32 bits directly as the target element type. FP8 was untested for decode-to-float so this was silent. VERIFIED (DemoConsole -- fp4-verify): CPU + OpenCL + CUDA ALL green - relu generic kernel 256/256 within 1 FP4 step, f32->FP4 convert 24/24 BIT-EXACT, FP4->f32 decode 16/16 BIT-EXACT. 
Regressions green: FP8 257/257 all 3 backends, bf16/Half oracle 65536/65536 bit-exact. Basic-ops-only => 4070-correct implies 1080/2060-correct. Browser backends (WGSL/GLSL/Wasm) next. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent e42f324 commit 1f29845

4 files changed

Lines changed: 296 additions & 22 deletions

File tree

ILGPU/Backends/PTX/PTXCodeGenerator.Emitter.cs

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -188,6 +188,11 @@ public void AppendArgument(ConstantRegister argument)
188188
case BasicValueType.Float8E5M2:
189189
AppendConstant((float)value.Float8E5M2Value);
190190
break;
191+
case BasicValueType.Float4E2M1:
192+
// FP4 is held in an f32 register on PTX (same f32-register model as
193+
// bf16/FP8); emit the f32 magnitude as the immediate.
194+
AppendConstant((float)value.Float4E2M1Value);
195+
break;
191196
case BasicValueType.Float32:
192197
AppendConstant(value.Float32Value);
193198
break;

ILGPU/Backends/PTX/PTXCodeGenerator.Values.cs

Lines changed: 255 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -241,14 +241,16 @@ public void GenerateCode(ConvertValue value)
241241
}
242242
}
243243

244-
// FP8 uses the SAME f32-register model: the FP8 value lives as f32 in-register and is
245-
// rounded to the 1-byte FP8 grid only at the store boundary (EmitF32ToFP8Bits). So an
246-
// FP8<->f32 (or FP8<->FP8) ConvertValue is a register no-op here - this is what makes
247-
// PrecisionConvert.ConvertToSingle/ConvertFromSingle<FP8> lower to nothing on PTX.
244+
// FP8 AND FP4 use the SAME f32-register model: the value lives as f32 in-register and is
245+
// rounded to the 1-byte grid only at the store boundary (EmitF32ToFP8Bits/EmitF32ToFP4Bits).
246+
// So a (FP8/FP4)<->f32 (or same-low-precision<->same) ConvertValue is a register no-op here -
247+
// this is what makes PrecisionConvert.ConvertToSingle/ConvertFromSingle<T> lower to nothing.
248248
bool srcFp8 = sourceType == ArithmeticBasicValueType.Float8E4M3
249-
|| sourceType == ArithmeticBasicValueType.Float8E5M2;
249+
|| sourceType == ArithmeticBasicValueType.Float8E5M2
250+
|| sourceType == ArithmeticBasicValueType.Float4E2M1;
250251
bool dstFp8 = targetType == ArithmeticBasicValueType.Float8E4M3
251-
|| targetType == ArithmeticBasicValueType.Float8E5M2;
252+
|| targetType == ArithmeticBasicValueType.Float8E5M2
253+
|| targetType == ArithmeticBasicValueType.Float4E2M1;
252254
if (srcFp8 || dstFp8)
253255
{
254256
if (srcFp8) sourceType = ArithmeticBasicValueType.Float32;
@@ -943,6 +945,200 @@ private void EmitF32ToFP8Bits(HardwareRegister srcF32, HardwareRegister dstByte,
943945
FreeRegister(p); FreeRegister(p2);
944946
}
945947

948+
/// <summary>
/// Emits PTX that decodes a raw FP4 E2M1 code into the bit pattern of the equivalent f32 value.
/// Only the low 4 bits of <paramref name="srcByte"/> are used: bit 3 is the sign, bits 2..1 are
/// the exponent field and bit 0 is the mantissa bit. The emitted sequence consists solely of
/// integer cvt/and/shl/shr/add/or/setp/selp/mov instructions and contains no branches.
/// </summary>
/// <remarks>
/// Decoding rules implemented by the emitted code:
/// a non-zero exponent field yields sign | ((exponent + 126) &lt;&lt; 23) | (mantissa &lt;&lt; 22),
/// i.e. the magnitudes 1, 1.5, 2, 3, 4 and 6;
/// a zero exponent field yields signed zero (mantissa 0) or signed 0.5 (mantissa 1).
/// </remarks>
/// <param name="srcByte">16-bit register holding the raw code (widened with cvt.u32.u16).</param>
/// <param name="dstF32">Register that receives the resulting 32 bits via mov.b32.</param>
private void EmitFP4BitsToF32(HardwareRegister srcByte, HardwareRegister dstF32)
{
    const int signShift = 28;            // moves nibble bit 3 to f32 bit 31
    const int f32ExponentShift = 23;     // position of the f32 exponent field
    const int f32TopMantissaShift = 22;  // position of the most significant f32 mantissa bit
    const int exponentRebias = 126;      // f32 bias (127) minus the E2M1 bias (1)
    const long halfPattern = 126L << 23; // f32 bit pattern of 0.5 (without sign)

    HardwareRegister NewInt32() =>
        AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);

    // Scratch registers; the allocation order is kept stable on purpose.
    var nibble = NewInt32();
    var signBit = NewInt32();
    var expField = NewInt32();
    var mantBit = NewInt32();
    var normalBits = NewInt32();
    var subnormalBits = NewInt32();
    var f32Bits = NewInt32();
    var scratch = NewInt32();
    var pred = AllocateRegister(BasicValueType.Int1, PTXRegisterKind.Predicate);

    // Emits "<opcode> r0, r1, ..." with register operands only.
    void Regs(string opcode, params HardwareRegister[] operands)
    {
        using var command = BeginCommand(opcode);
        foreach (var operand in operands)
        {
            command.AppendArgument(operand);
        }
    }

    // Emits "<opcode> target, source, immediate".
    void Imm(string opcode, HardwareRegister target, HardwareRegister source, long immediate)
    {
        using var command = BeginCommand(opcode);
        command.AppendArgument(target);
        command.AppendArgument(source);
        command.AppendConstant(immediate);
    }

    // Sets pred := (register == 0).
    void TestZero(HardwareRegister register)
    {
        using var command = BeginCommand("setp.eq.s32");
        command.AppendArgument(pred);
        command.AppendArgument(register);
        command.AppendConstant(0);
    }

    // Isolate the 4-bit code.
    Regs("cvt.u32.u16", nibble, srcByte);
    Imm("and.b32", nibble, nibble, 0x0F);

    // Split it into sign (already moved to f32 bit 31), exponent field and mantissa bit.
    Imm("and.b32", signBit, nibble, 0x8);
    Imm("shl.b32", signBit, signBit, signShift);
    Imm("shr.u32", expField, nibble, 1);
    Imm("and.b32", expField, expField, 0x3);
    Imm("and.b32", mantBit, nibble, 0x1);

    // Candidate for a non-zero exponent field: re-bias the exponent and place the mantissa bit.
    Imm("add.s32", scratch, expField, exponentRebias);
    Imm("shl.b32", scratch, scratch, f32ExponentShift);
    Regs("or.b32", normalBits, signBit, scratch);
    Imm("shl.b32", scratch, mantBit, f32TopMantissaShift);
    Regs("or.b32", normalBits, normalBits, scratch);

    // Candidate for a zero exponent field: signed 0.5, replaced by signed zero if mantissa == 0.
    Imm("or.b32", subnormalBits, signBit, halfPattern);
    TestZero(mantBit);
    Regs("selp.b32", subnormalBits, signBit, subnormalBits, pred);

    // Choose between both candidates based on the exponent field.
    TestZero(expField);
    Regs("selp.b32", f32Bits, subnormalBits, normalBits, pred);

    // Hand the raw bits over to the destination register.
    Regs("mov.b32", dstF32, f32Bits);

    // Release the scratch registers in the same order in which they were acquired.
    foreach (var register in new[]
    {
        nibble, signBit, expField, mantBit,
        normalBits, subnormalBits, f32Bits, scratch, pred,
    })
    {
        FreeRegister(register);
    }
}
1001+
1002+
/// <summary>
/// Emits PTX that encodes the f32 value in <paramref name="srcF32"/> as a raw FP4 E2M1 code
/// (sign in bit 3, two exponent bits, one mantissa bit) and writes it to
/// <paramref name="dstByte"/>. The emitted sequence is branch-free: every case is computed and the
/// final code is chosen with setp/selp. Only integer instructions are emitted.
/// </summary>
/// <remarks>
/// Selection rules implemented by the emitted code (ev = f32 exponent field - 127):
/// ev in 0..2: the 23-bit mantissa is rounded to 1 bit (round to nearest, ties to even); a carry
/// beyond code 0x7 is clamped to 0x7;
/// ev &lt; 0: the significand (with implicit leading one) is shifted right by 22 - ev and rounded
/// the same way; the shift count is limited to 31 before it is used, and a zero f32 exponent
/// field or a count above 31 yields signed zero;
/// ev &gt; 2 and infinity: sign | 0x7 (magnitude 6);
/// NaN: always 0x8, regardless of the input sign.
/// The original author notes that this mirrors the managed Float4E2M1 conversion; that code is
/// not visible here, so treat this as an assumption to verify.
/// </remarks>
/// <param name="srcF32">Register holding the f32 value (read through mov.b32).</param>
/// <param name="dstByte">16-bit register receiving the code (written through cvt.u16.u32).</param>
private void EmitF32ToFP4Bits(HardwareRegister srcF32, HardwareRegister dstByte)
{
    const int fp4MantissaBits = 1;
    const int fp4Bias = 1;
    const int droppedMantissaBits = 22; // f32 mantissa bits that do not fit into E2M1
    const int minNormalExponent = 0;
    const int f32InfinityBits = 0x7F800000;

    HardwareRegister NewInt32() =>
        AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);

    // Scratch registers; the allocation order is kept stable on purpose.
    var raw = NewInt32();
    var signNibble = NewInt32();
    var magnitude = NewInt32();
    var expField = NewInt32();
    var mantField = NewInt32();
    var unbiasedExp = NewInt32();
    var code = NewInt32();
    var normalCode = NewInt32();
    var subnormalCode = NewInt32();
    var significand = NewInt32();
    var rawShift = NewInt32();
    var safeShift = NewInt32();
    var kept = NewInt32();
    var roundBit = NewInt32();
    var sticky = NewInt32();
    var tmpA = NewInt32();
    var tmpB = NewInt32();
    var predA = AllocateRegister(BasicValueType.Int1, PTXRegisterKind.Predicate);
    var predB = AllocateRegister(BasicValueType.Int1, PTXRegisterKind.Predicate);

    // Emits "<opcode> r0, r1, ..." with register operands only.
    void Regs(string opcode, params HardwareRegister[] operands)
    {
        using var command = BeginCommand(opcode);
        foreach (var operand in operands)
        {
            command.AppendArgument(operand);
        }
    }

    // Emits "<opcode> target, source, immediate".
    void Imm(string opcode, HardwareRegister target, HardwareRegister source, long immediate)
    {
        using var command = BeginCommand(opcode);
        command.AppendArgument(target);
        command.AppendArgument(source);
        command.AppendConstant(immediate);
    }

    // Emits "mov.u32 target, immediate".
    void LoadImm(HardwareRegister target, long immediate)
    {
        using var command = BeginCommand("mov.u32");
        command.AppendArgument(target);
        command.AppendConstant(immediate);
    }

    // Emits "<opcode> predicate, register, immediate" with a 64-bit typed immediate.
    void Test(string opcode, HardwareRegister predicate, HardwareRegister register, long immediate)
    {
        using var command = BeginCommand(opcode);
        command.AppendArgument(predicate);
        command.AppendArgument(register);
        command.AppendConstant(immediate);
    }

    // Same as Test, but forwards the immediate as a 32-bit integer.
    void TestInt(string opcode, HardwareRegister predicate, HardwareRegister register, int immediate)
    {
        using var command = BeginCommand(opcode);
        command.AppendArgument(predicate);
        command.AppendArgument(register);
        command.AppendConstant(immediate);
    }

    // Emits "selp.b32 target, whenTrue, whenFalse, predicate".
    void Pick(
        HardwareRegister target,
        HardwareRegister whenTrue,
        HardwareRegister whenFalse,
        HardwareRegister predicate) =>
        Regs("selp.b32", target, whenTrue, whenFalse, predicate);

    // sticky := predA ? 1 : 0; oneHolder is the register that temporarily carries the 1.
    void MaterializeSticky(HardwareRegister oneHolder)
    {
        LoadImm(sticky, 0);
        LoadImm(oneHolder, 1);
        Pick(sticky, oneHolder, sticky, predA);
    }

    // predA := (roundBit == 1) && (sticky != 0 || (kept & 1) != 0); clobbers tmpA and predB.
    void EmitRoundUpTest()
    {
        Imm("and.b32", tmpA, kept, 1);
        Regs("or.b32", tmpA, sticky, tmpA);
        Test("setp.ne.s32", predA, tmpA, 0);
        Test("setp.eq.s32", predB, roundBit, 1);
        Regs("and.pred", predA, predA, predB);
    }

    // ---- Decompose the input: sign (as E2M1 bit 3), magnitude, exponent and mantissa ----
    Regs("mov.b32", raw, srcF32);
    Imm("shr.u32", signNibble, raw, 28);
    Imm("and.b32", signNibble, signNibble, 0x8);
    Imm("and.b32", magnitude, raw, 0x7FFFFFFF);
    Imm("shr.u32", expField, magnitude, 23);
    Imm("and.b32", expField, expField, 0xFF);
    Imm("and.b32", mantField, magnitude, 0x7FFFFF);
    Imm("sub.s32", unbiasedExp, expField, 127);

    // ---- Normal candidate: keep the top mantissa bit, round with the remaining 22 bits ----
    Imm("shr.u32", kept, mantField, droppedMantissaBits);
    Imm("shr.u32", roundBit, mantField, droppedMantissaBits - 1);
    Imm("and.b32", roundBit, roundBit, 1);
    Imm("and.b32", tmpA, mantField, (1 << (droppedMantissaBits - 1)) - 1);
    TestInt("setp.ne.s32", predA, tmpA, 0);
    MaterializeSticky(tmpB);
    // normalCode = ((unbiasedExp + bias) << mantissaBits) | kept
    Imm("add.s32", tmpA, unbiasedExp, fp4Bias);
    Imm("shl.b32", tmpA, tmpA, fp4MantissaBits);
    Regs("or.b32", normalCode, tmpA, kept);
    // Conditionally increment (the increment may carry into the exponent bits).
    EmitRoundUpTest();
    Imm("add.s32", tmpA, normalCode, 1);
    Pick(normalCode, tmpA, normalCode, predA);
    // Clamp a carry beyond the largest magnitude code, then attach the sign.
    TestInt("setp.gt.u32", predA, normalCode, 0x7);
    LoadImm(tmpB, 0x7);
    Pick(normalCode, tmpB, normalCode, predA);
    Imm("and.b32", normalCode, normalCode, 0x7);
    Regs("or.b32", normalCode, signNibble, normalCode);

    // ---- Subnormal candidate (used when unbiasedExp < 0) ----
    // rawShift = (minNormalExponent - unbiasedExp) + droppedMantissaBits
    Imm("or.b32", significand, mantField, 0x800000);
    LoadImm(tmpA, minNormalExponent);
    Regs("sub.s32", rawShift, tmpA, unbiasedExp);
    Imm("add.s32", rawShift, rawShift, droppedMantissaBits);
    // safeShift = min(rawShift, 31)
    LoadImm(tmpA, 31);
    Regs("min.s32", safeShift, rawShift, tmpA);
    Regs("shr.u32", kept, significand, safeShift);
    // roundBit = (significand >> (safeShift - 1)) & 1; tmpA keeps safeShift - 1 for the mask below.
    Imm("sub.s32", tmpA, safeShift, 1);
    Regs("shr.u32", roundBit, significand, tmpA);
    Imm("and.b32", roundBit, roundBit, 1);
    // sticky = (significand & ((1 << (safeShift - 1)) - 1)) != 0
    LoadImm(tmpB, 1);
    Regs("shl.b32", tmpB, tmpB, tmpA);
    Imm("sub.s32", tmpB, tmpB, 1);
    Regs("and.b32", tmpB, significand, tmpB);
    Test("setp.ne.s32", predA, tmpB, 0);
    MaterializeSticky(tmpA);
    EmitRoundUpTest();
    Imm("add.s32", tmpA, kept, 1);
    Pick(kept, tmpA, kept, predA);
    // subnormalCode = sign | (kept & 0x7)
    Imm("and.b32", tmpA, kept, 0x7);
    Regs("or.b32", subnormalCode, signNibble, tmpA);
    // Guards: a zero f32 exponent field or an unclamped shift above 31 yields signed zero.
    Test("setp.eq.s32", predA, expField, 0);
    Pick(subnormalCode, signNibble, subnormalCode, predA);
    Test("setp.gt.s32", predA, rawShift, 31);
    Pick(subnormalCode, signNibble, subnormalCode, predA);

    // ---- Final selection; later overrides take precedence over earlier ones ----
    Regs("mov.u32", code, normalCode);
    Test("setp.lt.s32", predA, unbiasedExp, minNormalExponent);
    Pick(code, subnormalCode, code, predA);
    // Exponent above the representable range: saturate to sign | 0x7.
    Test("setp.gt.s32", predA, unbiasedExp, 2);
    Imm("or.b32", tmpA, signNibble, 0x7);
    Pick(code, tmpA, code, predA);
    // Infinity: saturate to sign | 0x7 as well.
    Imm("or.b32", tmpA, signNibble, 0x7);
    TestInt("setp.eq.s32", predA, magnitude, f32InfinityBits);
    Pick(code, tmpA, code, predA);
    // NaN: fixed code 0x8, the sign of the input is ignored.
    LoadImm(tmpB, 0x8);
    TestInt("setp.gt.u32", predA, magnitude, f32InfinityBits);
    Pick(code, tmpB, code, predA);

    // Narrow the selected code into the 16-bit destination register.
    Imm("and.b32", code, code, 0xFF);
    Regs("cvt.u16.u32", dstByte, code);

    // Release the scratch registers in the same order in which they were acquired.
    foreach (var register in new[]
    {
        raw, signNibble, magnitude, expField,
        mantField, unbiasedExp, code, normalCode,
        subnormalCode, significand, rawShift, safeShift,
        kept, roundBit, sticky, tmpA, tmpB,
        predA, predB,
    })
    {
        FreeRegister(register);
    }
}
1141+
9461142
/// <summary cref="IBackendCodeGenerator.GenerateCode(Load)"/>
9471143
public void GenerateCode(Load load)
9481144
{
@@ -994,6 +1190,25 @@ public void GenerateCode(Load load)
9941190
return;
9951191
}
9961192

1193+
if (load.Type.BasicValueType == BasicValueType.Float4E2M1)
1194+
{
1195+
// FP4 storage is a packed 1-byte value (4-bit E2M1 in the low nibble); load the byte
1196+
// into a temp .b16 register, then widen to the f32 value register via portable bit-manip
1197+
// (EmitFP4BitsToF32 - every CUDA arch). f32-register model like bf16/FP8.
1198+
var fp4Target = AllocateHardware(load);
1199+
var rawReg = AllocateRegister(BasicValueType.Int16, PTXRegisterKind.Int16);
1200+
using (var cmd = BeginCommand(PTXInstructions.LoadOperation))
1201+
{
1202+
cmd.AppendAddressSpace(sourceType.AddressSpace);
1203+
cmd.AppendSuffix("u8");
1204+
cmd.AppendArgument(rawReg);
1205+
cmd.AppendArgumentValue(address, 0);
1206+
}
1207+
EmitFP4BitsToF32(rawReg, fp4Target);
1208+
FreeRegister(rawReg);
1209+
return;
1210+
}
1211+
9971212
var targetRegister = Allocate(load);
9981213

9991214
EmitVectorizedCommand(
@@ -1149,18 +1364,41 @@ public void GenerateCode(Store store)
11491364
return;
11501365
}
11511366

1152-
// A bf16-TYPED value stored to a NON-bf16 buffer (the target-bf16 case was handled above).
1153-
// bf16 is held in an f32 register and the `(float)bf16` widening Convert is a no-op alias
1154-
// that preserves the bf16 IR type, so `floatBuf[i] = (float)bf16Buf[i]` reaches here with a
1155-
// bf16-typed value register. Falling through to EmitIOStore would re-narrow it (cvt.rn.bf16.f32
1156-
// + st.b16) into the wider (e.g. 4-byte f32) destination slot -> the value reads back ~0
1157-
// (Tuvok's "bf16 store/load returns zeros" bug). Store the f32 bits directly as the target
1158-
// element type instead. (Struct-field bf16 stores keep using EmitIOStore: there the register
1159-
// type and the field storage type agree, so its register-type-keyed packing is correct.)
1160-
if (value is PrimitiveRegister bf16Value &&
1161-
bf16Value.BasicValueType == BasicValueType.BFloat16)
1367+
if (targetType.ElementType.BasicValueType == BasicValueType.Float4E2M1)
11621368
{
1163-
var f32Reg = EnsureHardwareRegister(bf16Value);
1369+
// FP4 store: round the f32 value register to the 1-byte E2M1 pattern (low nibble) via
1370+
// portable bit-manip (EmitF32ToFP4Bits - every CUDA arch) into a temp .b16 register, then
1371+
// write the low byte. Keyed off the TARGET BUFFER element type (same as bf16/FP8).
1372+
var valueReg = EnsureHardwareRegister(value.AsNotNullCast<PrimitiveRegister>());
1373+
var rawReg = AllocateRegister(BasicValueType.Int16, PTXRegisterKind.Int16);
1374+
EmitF32ToFP4Bits(valueReg, rawReg);
1375+
using (var cmd = BeginCommand(PTXInstructions.StoreOperation))
1376+
{
1377+
cmd.AppendAddressSpace(targetType.AddressSpace);
1378+
cmd.AppendSuffix("u8");
1379+
cmd.AppendArgumentValue(address, 0);
1380+
cmd.AppendArgument(rawReg);
1381+
}
1382+
FreeRegister(rawReg);
1383+
return;
1384+
}
1385+
1386+
// A low-precision-TYPED value (bf16 / FP8 / FP4 - all held in an f32 register on PTX) stored
1387+
// to a NON-matching buffer (the target-matching cases were handled above). The widening
1388+
// Convert (`(float)bf16` etc.) is a no-op alias that preserves the low-precision IR type, so
1389+
// `floatBuf[i] = (float)lowpBuf[i]` reaches here with a low-precision-typed value register that
1390+
// actually holds the widened f32 value. Falling through to EmitIOStore would re-narrow it
1391+
// (cvt/round + st.b16/st.b8) into the wider (e.g. 4-byte f32) destination slot -> the value
1392+
// reads back ~0 (Tuvok's "bf16 store/load returns zeros" bug; the same latent bug existed for
1393+
// FP8/FP4). Store the f32 bits directly as the target element type instead. (Struct-field
1394+
// stores keep using EmitIOStore: there the register type and field storage type agree.)
1395+
if (value is PrimitiveRegister lowpValue &&
1396+
(lowpValue.BasicValueType == BasicValueType.BFloat16 ||
1397+
lowpValue.BasicValueType == BasicValueType.Float8E4M3 ||
1398+
lowpValue.BasicValueType == BasicValueType.Float8E5M2 ||
1399+
lowpValue.BasicValueType == BasicValueType.Float4E2M1))
1400+
{
1401+
var f32Reg = EnsureHardwareRegister(lowpValue);
11641402
using var cmd = BeginCommand(PTXInstructions.StoreOperation);
11651403
cmd.AppendAddressSpace(targetType.AddressSpace);
11661404
cmd.AppendSuffix(ResolveIOType(targetType.ElementType.BasicValueType));

0 commit comments

Comments
 (0)