@@ -241,14 +241,16 @@ public void GenerateCode(ConvertValue value)
241241 }
242242 }
243243
244- // FP8 uses the SAME f32-register model: the FP8 value lives as f32 in-register and is
245- // rounded to the 1-byte FP8 grid only at the store boundary (EmitF32ToFP8Bits). So an
246- // FP8<->f32 (or FP8 <->FP8 ) ConvertValue is a register no-op here - this is what makes
247- // PrecisionConvert.ConvertToSingle/ConvertFromSingle<FP8 > lower to nothing on PTX .
244+ // FP8 AND FP4 use the SAME f32-register model: the value lives as f32 in-register and is
245+ // rounded to the 1-byte grid only at the store boundary (EmitF32ToFP8Bits/EmitF32ToFP4Bits).
246+ // So a (FP8/FP4)<->f32 (or same-low-precision<->same) ConvertValue is a register no-op here -
247+ // this is what makes PrecisionConvert.ConvertToSingle/ConvertFromSingle<T> lower to nothing.
248248 bool srcFp8 = sourceType == ArithmeticBasicValueType . Float8E4M3
249- || sourceType == ArithmeticBasicValueType . Float8E5M2 ;
249+ || sourceType == ArithmeticBasicValueType . Float8E5M2
250+ || sourceType == ArithmeticBasicValueType . Float4E2M1 ;
250251 bool dstFp8 = targetType == ArithmeticBasicValueType . Float8E4M3
251- || targetType == ArithmeticBasicValueType . Float8E5M2 ;
252+ || targetType == ArithmeticBasicValueType . Float8E5M2
253+ || targetType == ArithmeticBasicValueType . Float4E2M1 ;
252254 if ( srcFp8 || dstFp8 )
253255 {
254256 if ( srcFp8 ) sourceType = ArithmeticBasicValueType . Float32 ;
@@ -943,6 +945,200 @@ private void EmitF32ToFP8Bits(HardwareRegister srcF32, HardwareRegister dstByte,
943945 FreeRegister ( p ) ; FreeRegister ( p2 ) ;
944946 }
945947
/// <summary>
/// Emits PTX that decodes a raw FP4 E2M1 code (the low nibble of a 16-bit source register)
/// into the f32 bit pattern of the same value and moves it into <paramref name="dstF32"/>.
/// Only integer cvt/and/shift/add/or plus setp/selp instructions are emitted - no branches.
/// Exponent field 0 yields +-0 (mantissa 0) or +-0.5 (mantissa 1); every other code is a
/// normal number with f32 exponent field e + 126 and the mantissa bit at f32 bit 22.
/// Meant to mirror the managed ConvertFloat4E2M1ToFloat helper (defined outside this view).
/// </summary>
/// <param name="srcByte">16-bit register with the raw code; bits above the low nibble are masked away.</param>
/// <param name="dstF32">Register that receives the decoded f32 bits via mov.b32.</param>
private void EmitFP4BitsToF32(HardwareRegister srcByte, HardwareRegister dstF32)
{
    HardwareRegister NewInt32() =>
        AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);

    // op dst, src, imm
    void WithImmediate(string op, HardwareRegister dst, HardwareRegister src, long imm)
    {
        using var cmd = BeginCommand(op);
        cmd.AppendArgument(dst);
        cmd.AppendArgument(src);
        cmd.AppendConstant(imm);
    }

    // op dst, src
    void Unary(string op, HardwareRegister dst, HardwareRegister src)
    {
        using var cmd = BeginCommand(op);
        cmd.AppendArgument(dst);
        cmd.AppendArgument(src);
    }

    // or.b32 dst, lhs, rhs
    void Or(HardwareRegister dst, HardwareRegister lhs, HardwareRegister rhs)
    {
        using var cmd = BeginCommand("or.b32");
        cmd.AppendArgument(dst);
        cmd.AppendArgument(lhs);
        cmd.AppendArgument(rhs);
    }

    // setp.eq.s32 predicate, value, 0
    void IsZero(HardwareRegister predicate, HardwareRegister value)
    {
        using var cmd = BeginCommand("setp.eq.s32");
        cmd.AppendArgument(predicate);
        cmd.AppendArgument(value);
        cmd.AppendConstant(0);
    }

    // selp.b32 dst, whenTrue, whenFalse, predicate
    void Select(
        HardwareRegister dst,
        HardwareRegister whenTrue,
        HardwareRegister whenFalse,
        HardwareRegister predicate)
    {
        using var cmd = BeginCommand("selp.b32");
        cmd.AppendArgument(dst);
        cmd.AppendArgument(whenTrue);
        cmd.AppendArgument(whenFalse);
        cmd.AppendArgument(predicate);
    }

    // Scratch registers; the allocation order is kept stable on purpose.
    var raw = NewInt32();
    var signBits = NewInt32();
    var expField = NewInt32();
    var mantBit = NewInt32();
    var normalBits = NewInt32();
    var subnormalBits = NewInt32();
    var f32Bits = NewInt32();
    var scratch = NewInt32();
    var pred = AllocateRegister(BasicValueType.Int1, PTXRegisterKind.Predicate);

    // raw = zero-extended source, reduced to the 4-bit code
    Unary("cvt.u32.u16", raw, srcByte);
    WithImmediate("and.b32", raw, raw, 0x0F);

    // signBits: E2M1 sign (bit 3) relocated to the f32 sign position (bit 31)
    WithImmediate("and.b32", signBits, raw, 0x8);
    WithImmediate("shl.b32", signBits, signBits, 28);

    // expField = bits 2..1 of the code, mantBit = bit 0
    WithImmediate("shr.u32", expField, raw, 1);
    WithImmediate("and.b32", expField, expField, 0x3);
    WithImmediate("and.b32", mantBit, raw, 0x1);

    // Normal candidate: sign | ((expField - 1 + 127) << 23) | (mantBit << 22)
    WithImmediate("add.s32", scratch, expField, 126);
    WithImmediate("shl.b32", scratch, scratch, 23);
    Or(normalBits, signBits, scratch);
    WithImmediate("shl.b32", scratch, mantBit, 22);
    Or(normalBits, normalBits, scratch);

    // Zero-exponent candidate: mantBit == 0 gives signed zero, otherwise sign | 0.5
    // (0.5 has f32 exponent field 126 and an all-zero mantissa).
    WithImmediate("or.b32", subnormalBits, signBits, 126L << 23);
    IsZero(pred, mantBit);
    Select(subnormalBits, signBits, subnormalBits, pred);

    // Choose the zero-exponent candidate when expField == 0, else the normal candidate.
    IsZero(pred, expField);
    Select(f32Bits, subnormalBits, normalBits, pred);

    // Bit-move the assembled pattern into the destination register.
    Unary("mov.b32", dstF32, f32Bits);

    // Release the scratch registers in the same order as before.
    var scratchRegisters = new HardwareRegister[]
    {
        raw, signBits, expField, mantBit,
        normalBits, subnormalBits, f32Bits, scratch, pred
    };
    foreach (var register in scratchRegisters)
        FreeRegister(register);
}
1001+
/// <summary>
/// Emits PTX that narrows the f32 value in <paramref name="srcF32"/> to a raw FP4 E2M1 code
/// (sign in bit 3, 2 exponent bits, 1 mantissa bit) and writes it to the 16-bit register
/// <paramref name="dstByte"/>. The sequence is branchless (setp/selp only) and uses
/// round-to-nearest-even. Three candidates are computed and one is selected:
/// a normal candidate (unbiased exponent 0..2), a subnormal candidate (exponent below 0),
/// and the saturated code sign|0x7 (+-6) for finite overflow and +-Inf.
/// NaN always produces 0x8 (-0) regardless of its sign bit.
/// The subnormal shift count is clamped to 31 so every emitted shift stays below the 32-bit
/// register width (the PTX ISA clamps larger shift amounts rather than leaving them undefined,
/// so the clamp is defensive), and the subnormal candidate is forced to signed zero when the
/// unclamped shift exceeds 31 or the f32 exponent field is zero.
/// Meant to mirror the managed ConvertFloatToFloat4E2M1 helper (defined outside this view).
/// </summary>
/// <param name="srcF32">Register holding the f32 value to narrow; read via mov.b32.</param>
/// <param name="dstByte">16-bit register that receives the code in its low nibble.</param>
private void EmitF32ToFP4Bits(HardwareRegister srcF32, HardwareRegister dstByte)
{
    // E2M1 layout: 1 mantissa bit, exponent bias 1; 22 of the 23 f32 mantissa bits are dropped.
    // eMin..eMax are the unbiased exponents encoded as normal numbers; maxCode is the
    // magnitude code of the largest finite value (6).
    const int mantBits = 1, bias = 1, dropBits = 22, eMin = 0, eMax = 2, maxCode = 0x7;

    var bits = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var sign = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var rest = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var f32Exp = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var f32Mant = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var ev = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var result = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var nrm = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var sub = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var signif = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var shift = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var sshift = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var mt = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var rb = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var stk = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var t = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var t2 = AllocateRegister(BasicValueType.Int32, PTXRegisterKind.Int32);
    var p = AllocateRegister(BasicValueType.Int1, PTXRegisterKind.Predicate);
    var p2 = AllocateRegister(BasicValueType.Int1, PTXRegisterKind.Predicate);

    // op a[0], a[1], ... (register operands only)
    void Emit(string op, params HardwareRegister[] a)
    {
        using var c = BeginCommand(op);
        foreach (var x in a)
            c.AppendArgument(x);
    }

    // op d, a, imm
    void EmitI(string op, HardwareRegister d, HardwareRegister a, long imm)
    {
        using var c = BeginCommand(op);
        c.AppendArgument(d);
        c.AppendArgument(a);
        c.AppendConstant(imm);
    }

    // mov.u32 d, imm
    void MovI(HardwareRegister d, long imm)
    {
        using var c = BeginCommand("mov.u32");
        c.AppendArgument(d);
        c.AppendConstant(imm);
    }

    // setp-style compare of a register against an immediate: op pr, a, imm
    void SetpI(string op, HardwareRegister pr, HardwareRegister a, long imm)
    {
        using var c = BeginCommand(op);
        c.AppendArgument(pr);
        c.AppendArgument(a);
        c.AppendConstant(imm);
    }

    // selp.b32 d, tv, fv, pr   (d = pr ? tv : fv)
    void Selp(HardwareRegister d, HardwareRegister tv, HardwareRegister fv, HardwareRegister pr)
    {
        using var c = BeginCommand("selp.b32");
        c.AppendArgument(d);
        c.AppendArgument(tv);
        c.AppendArgument(fv);
        c.AppendArgument(pr);
    }

    // bits = reinterpret(srcF32); sign = (bits>>28)&0x8 (E2M1 sign bit 3); rest = bits & 0x7FFFFFFF
    using (var c = BeginCommand("mov.b32")) { c.AppendArgument(bits); c.AppendArgument(srcF32); }
    EmitI("shr.u32", sign, bits, 28);
    EmitI("and.b32", sign, sign, 0x8);
    EmitI("and.b32", rest, bits, 0x7FFFFFFF);
    // f32Exp = (rest>>23)&0xFF; f32Mant = rest & 0x7FFFFF; ev = f32Exp - 127
    EmitI("shr.u32", f32Exp, rest, 23);
    EmitI("and.b32", f32Exp, f32Exp, 0xFF);
    EmitI("and.b32", f32Mant, rest, 0x7FFFFF);
    EmitI("sub.s32", ev, f32Exp, 127);

    // ---- NORMAL candidate (ev in eMin..eMax): round the 23-bit mantissa to 1 bit, RNE ----
    // mt = kept mantissa bit, rb = first dropped bit, stk = any lower dropped bit set
    EmitI("shr.u32", mt, f32Mant, dropBits);
    EmitI("shr.u32", rb, f32Mant, dropBits - 1);
    EmitI("and.b32", rb, rb, 1);
    EmitI("and.b32", t, f32Mant, (1 << (dropBits - 1)) - 1);
    using (var c = BeginCommand("setp.ne.s32")) { c.AppendArgument(p); c.AppendArgument(t); c.AppendConstant(0); }
    MovI(stk, 0); MovI(t2, 1); Selp(stk, t2, stk, p);
    // nrm = ((ev+bias)<<mantBits) | mt
    EmitI("add.s32", t, ev, bias);
    EmitI("shl.b32", t, t, mantBits);
    Emit("or.b32", nrm, t, mt);
    // roundUp if rb==1 && (stk!=0 || (mt&1)); the +1 may carry into the exponent field
    EmitI("and.b32", t, mt, 1);
    Emit("or.b32", t, stk, t);
    SetpI("setp.ne.s32", p, t, 0);
    SetpI("setp.eq.s32", p2, rb, 1);
    using (var c = BeginCommand("and.pred")) { c.AppendArgument(p); c.AppendArgument(p); c.AppendArgument(p2); }
    EmitI("add.s32", t, nrm, 1);
    Selp(nrm, t, nrm, p);
    // saturate: a carry past +-6 (nrm > maxCode) clamps to maxCode (no larger finite, no Inf)
    using (var c = BeginCommand("setp.gt.u32")) { c.AppendArgument(p); c.AppendArgument(nrm); c.AppendConstant(maxCode); }
    MovI(t2, maxCode); Selp(nrm, t2, nrm, p);
    EmitI("and.b32", nrm, nrm, 0x7);
    Emit("or.b32", nrm, sign, nrm);

    // ---- SUBNORMAL candidate (ev < eMin) ----
    // signif = f32Mant | 0x800000 ; shift = (eMin - ev) + dropBits (= -ev + 22)
    EmitI("or.b32", signif, f32Mant, 0x800000);
    MovI(t, eMin);
    Emit("sub.s32", shift, t, ev);
    EmitI("add.s32", shift, shift, dropBits);
    // sshift = min(shift, 31): keeps every shift count below the register width
    MovI(t, 31);
    using (var c = BeginCommand("min.s32")) { c.AppendArgument(sshift); c.AppendArgument(shift); c.AppendArgument(t); }
    // mt = signif >> sshift
    Emit("shr.u32", mt, signif, sshift);
    // rb = (signif >> (sshift-1)) & 1
    EmitI("sub.s32", t, sshift, 1);
    Emit("shr.u32", rb, signif, t);
    EmitI("and.b32", rb, rb, 1);
    // stk = (signif & ((1<<(sshift-1))-1)) != 0
    MovI(t2, 1);
    Emit("shl.b32", t2, t2, t);
    EmitI("sub.s32", t2, t2, 1);
    Emit("and.b32", t2, signif, t2);
    SetpI("setp.ne.s32", p, t2, 0);
    MovI(stk, 0); MovI(t, 1); Selp(stk, t, stk, p);
    // roundUp if rb==1 && (stk || mt&1)
    EmitI("and.b32", t, mt, 1);
    Emit("or.b32", t, stk, t);
    SetpI("setp.ne.s32", p, t, 0);
    SetpI("setp.eq.s32", p2, rb, 1);
    using (var c = BeginCommand("and.pred")) { c.AppendArgument(p); c.AppendArgument(p); c.AppendArgument(p2); }
    EmitI("add.s32", t, mt, 1);
    Selp(mt, t, mt, p);
    // sub = sign | (mt & 0x7)
    EmitI("and.b32", t, mt, 0x7);
    Emit("or.b32", sub, sign, t);
    // guards: f32Exp==0 -> sign ; shift>31 -> sign   (both collapse to signed zero)
    SetpI("setp.eq.s32", p, f32Exp, 0);
    Selp(sub, sign, sub, p);
    SetpI("setp.gt.s32", p, shift, 31);
    Selp(sub, sign, sub, p);

    // ---- assemble: result = normal; if ev<eMin -> sub; overflow/Inf -> +-6; NaN -> 0x8 ----
    Emit("mov.u32", result, nrm);
    SetpI("setp.lt.s32", p, ev, eMin);
    Selp(result, sub, result, p);
    // t = sign|maxCode (saturated +-6); computed once and shared by the overflow and Inf selects
    EmitI("or.b32", t, sign, maxCode);
    // finite overflow ev>eMax -> saturate
    SetpI("setp.gt.s32", p, ev, eMax);
    Selp(result, t, result, p);
    // Inf (rest == 0x7F800000) -> saturate; t still holds sign|maxCode here
    using (var c = BeginCommand("setp.eq.s32")) { c.AppendArgument(p); c.AppendArgument(rest); c.AppendConstant(0x7F800000); }
    Selp(result, t, result, p);
    // NaN (rest > 0x7F800000) -> 0x8 (-0), UNCONDITIONAL (the input sign is ignored)
    MovI(t2, 0x8);
    using (var c = BeginCommand("setp.gt.u32")) { c.AppendArgument(p); c.AppendArgument(rest); c.AppendConstant(0x7F800000); }
    Selp(result, t2, result, p);

    // dstByte = (u16)(result & 0xFF)
    EmitI("and.b32", result, result, 0xFF);
    using (var c = BeginCommand("cvt.u16.u32")) { c.AppendArgument(dstByte); c.AppendArgument(result); }

    FreeRegister(bits); FreeRegister(sign); FreeRegister(rest); FreeRegister(f32Exp);
    FreeRegister(f32Mant); FreeRegister(ev); FreeRegister(result); FreeRegister(nrm);
    FreeRegister(sub); FreeRegister(signif); FreeRegister(shift); FreeRegister(sshift);
    FreeRegister(mt); FreeRegister(rb); FreeRegister(stk); FreeRegister(t); FreeRegister(t2);
    FreeRegister(p); FreeRegister(p2);
}
1141+
9461142 /// <summary cref="IBackendCodeGenerator.GenerateCode(Load)"/>
9471143 public void GenerateCode ( Load load )
9481144 {
@@ -994,6 +1190,25 @@ public void GenerateCode(Load load)
9941190 return ;
9951191 }
9961192
1193+ if ( load . Type . BasicValueType == BasicValueType . Float4E2M1 )
1194+ {
1195+ // FP4 storage is a packed 1-byte value (4-bit E2M1 in the low nibble); load the byte
1196+ // into a temp .b16 register, then widen to the f32 value register via portable bit-manip
1197+ // (EmitFP4BitsToF32 - every CUDA arch). f32-register model like bf16/FP8.
1198+ var fp4Target = AllocateHardware ( load ) ;
1199+ var rawReg = AllocateRegister ( BasicValueType . Int16 , PTXRegisterKind . Int16 ) ;
1200+ using ( var cmd = BeginCommand ( PTXInstructions . LoadOperation ) )
1201+ {
1202+ cmd . AppendAddressSpace ( sourceType . AddressSpace ) ;
1203+ cmd . AppendSuffix ( "u8" ) ;
1204+ cmd . AppendArgument ( rawReg ) ;
1205+ cmd . AppendArgumentValue ( address , 0 ) ;
1206+ }
1207+ EmitFP4BitsToF32 ( rawReg , fp4Target ) ;
1208+ FreeRegister ( rawReg ) ;
1209+ return ;
1210+ }
1211+
9971212 var targetRegister = Allocate ( load ) ;
9981213
9991214 EmitVectorizedCommand (
@@ -1149,18 +1364,41 @@ public void GenerateCode(Store store)
11491364 return ;
11501365 }
11511366
1152- // A bf16-TYPED value stored to a NON-bf16 buffer (the target-bf16 case was handled above).
1153- // bf16 is held in an f32 register and the `(float)bf16` widening Convert is a no-op alias
1154- // that preserves the bf16 IR type, so `floatBuf[i] = (float)bf16Buf[i]` reaches here with a
1155- // bf16-typed value register. Falling through to EmitIOStore would re-narrow it (cvt.rn.bf16.f32
1156- // + st.b16) into the wider (e.g. 4-byte f32) destination slot -> the value reads back ~0
1157- // (Tuvok's "bf16 store/load returns zeros" bug). Store the f32 bits directly as the target
1158- // element type instead. (Struct-field bf16 stores keep using EmitIOStore: there the register
1159- // type and the field storage type agree, so its register-type-keyed packing is correct.)
1160- if ( value is PrimitiveRegister bf16Value &&
1161- bf16Value . BasicValueType == BasicValueType . BFloat16 )
1367+ if ( targetType . ElementType . BasicValueType == BasicValueType . Float4E2M1 )
11621368 {
1163- var f32Reg = EnsureHardwareRegister ( bf16Value ) ;
1369+ // FP4 store: round the f32 value register to the 1-byte E2M1 pattern (low nibble) via
1370+ // portable bit-manip (EmitF32ToFP4Bits - every CUDA arch) into a temp .b16 register, then
1371+ // write the low byte. Keyed off the TARGET BUFFER element type (same as bf16/FP8).
1372+ var valueReg = EnsureHardwareRegister ( value . AsNotNullCast < PrimitiveRegister > ( ) ) ;
1373+ var rawReg = AllocateRegister ( BasicValueType . Int16 , PTXRegisterKind . Int16 ) ;
1374+ EmitF32ToFP4Bits ( valueReg , rawReg ) ;
1375+ using ( var cmd = BeginCommand ( PTXInstructions . StoreOperation ) )
1376+ {
1377+ cmd . AppendAddressSpace ( targetType . AddressSpace ) ;
1378+ cmd . AppendSuffix ( "u8" ) ;
1379+ cmd . AppendArgumentValue ( address , 0 ) ;
1380+ cmd . AppendArgument ( rawReg ) ;
1381+ }
1382+ FreeRegister ( rawReg ) ;
1383+ return ;
1384+ }
1385+
1386+ // A low-precision-TYPED value (bf16 / FP8 / FP4 - all held in an f32 register on PTX) stored
1387+ // to a NON-matching buffer (the target-matching cases were handled above). The widening
1388+ // Convert (`(float)bf16` etc.) is a no-op alias that preserves the low-precision IR type, so
1389+ // `floatBuf[i] = (float)lowpBuf[i]` reaches here with a low-precision-typed value register that
1390+ // actually holds the widened f32 value. Falling through to EmitIOStore would re-narrow it
1391+ // (cvt/round + st.b16/st.b8) into the wider (e.g. 4-byte f32) destination slot -> the value
1392+ // reads back ~0 (Tuvok's "bf16 store/load returns zeros" bug; the same latent bug existed for
1393+ // FP8/FP4). Store the f32 bits directly as the target element type instead. (Struct-field
1394+ // stores keep using EmitIOStore: there the register type and field storage type agree.)
1395+ if ( value is PrimitiveRegister lowpValue &&
1396+ ( lowpValue . BasicValueType == BasicValueType . BFloat16 ||
1397+ lowpValue . BasicValueType == BasicValueType . Float8E4M3 ||
1398+ lowpValue . BasicValueType == BasicValueType . Float8E5M2 ||
1399+ lowpValue . BasicValueType == BasicValueType . Float4E2M1 ) )
1400+ {
1401+ var f32Reg = EnsureHardwareRegister ( lowpValue ) ;
11641402 using var cmd = BeginCommand ( PTXInstructions . StoreOperation ) ;
11651403 cmd . AppendAddressSpace ( targetType . AddressSpace ) ;
11661404 cmd . AppendSuffix ( ResolveIOType ( targetType . ElementType . BasicValueType ) ) ;
0 commit comments