@@ -391,6 +391,141 @@ def cleanup(self) -> None:
391391# ============================================================================
392392
393393
394+ # ---------------------------------------------------------------------------
395+ # fp8 (E4M3) / bf8 (E5M2) -- FNUZ ("NANOO") encoding used by gfx942/MI300.
396+ #
397+ # numpy has no native 8-bit float, and the C ABI only cares about the 1-byte
398+ # memory layout (sizeof(fp8_t) == sizeof(bf8_t) == 1). We carry the value as a
399+ # uint8 bit pattern. As with bf16, the DECODE is the load-bearing half: it must
400+ # return the exact value the device's fp8_t/bf8_t represents for a byte, so the
401+ # NumPy reference multiplies bit-for-bit what the GPU multiplies. The ENCODE only
402+ # needs to land on the nearest representable byte.
403+ #
404+ # FNUZ format (gfx942): bias = 2^(exp_bits-1); the all-1s exponent is a normal
405+ # number (no Inf), the sole NaN is the sign=1/exp=0/mant=0 byte (0x80), and there
406+ # is no negative zero. gfx950/MI350 uses the OCP fp8 format instead; this codec
407+ # targets the gfx942 default and the OCP path needs separate handling (the runner
408+ # raises a clear error for fp8/bf8 on a non-gfx942 arch).
409+ # ---------------------------------------------------------------------------
410+
411+
412+ def _fnuz_decode_table (exp_bits : int , mant_bits : int ) -> np .ndarray :
413+ """Build the 256-entry byte -> fp32 value table for an 8-bit FNUZ float."""
414+ bias = (1 << (exp_bits - 1 ))
415+ mant_max = 1 << mant_bits
416+ sign_shift = exp_bits + mant_bits
417+ exp_mask = (1 << exp_bits ) - 1
418+ table = np .zeros (256 , dtype = np .float32 )
419+ for b in range (256 ):
420+ sign = (b >> sign_shift ) & 1
421+ exp = (b >> mant_bits ) & exp_mask
422+ mant = b & (mant_max - 1 )
423+ if exp == 0 and mant == 0 :
424+ # +0 (0x00); the negative-zero slot (0x80) is the lone NaN.
425+ table [b ] = np .float32 (np .nan ) if sign else np .float32 (0.0 )
426+ continue
427+ if exp == 0 :
428+ val = (mant / mant_max ) * (2.0 ** (1 - bias )) # subnormal
429+ else :
430+ val = (1.0 + mant / mant_max ) * (2.0 ** (exp - bias )) # normal
431+ table [b ] = np .float32 (- val if sign else val )
432+ return table
433+
434+
def _fnuz_encode(x: np.ndarray, exp_bits: int, mant_bits: int) -> np.ndarray:
    """Quantise fp32 values to the nearest 8-bit FNUZ float (uint8 bit patterns).

    Magnitudes beyond the largest representable value saturate to it, an exact
    tie between two neighbours resolves to the smaller magnitude, anything that
    rounds to zero becomes +0 (never the 0x80 byte), and NaN inputs become the
    NaN byte. The result keeps the input's memory order (C or F) so a
    column-major operand stays column-major after encoding.

    Args:
        x: Values to encode; converted to float32 first.
        exp_bits: Exponent field width of the target format.
        mant_bits: Mantissa field width of the target format.

    Returns:
        A uint8 array of encoded bytes.
    """
    sign_bit = np.uint8(1 << (exp_bits + mant_bits))  # 0x80

    # Bytes below the sign bit hold every non-negative magnitude. Work in
    # float64 so that very large inputs are compared against the top
    # magnitudes without running into fp32 resolution limits.
    magnitudes = _fnuz_decode_table(exp_bits, mant_bits)[: int(sign_bit)].astype(
        np.float64
    )
    rank = np.argsort(magnitudes)
    levels = magnitudes[rank]
    level_codes = rank.astype(np.uint8)
    top = levels.size - 1

    values = np.asarray(x, dtype=np.float32)
    if not values.flags["C_CONTIGUOUS"] and not values.flags["F_CONTIGUOUS"]:
        values = np.ascontiguousarray(values)
    magnitude = np.abs(values).astype(np.float64)

    # The insertion point yields both candidates: past-the-end collapses both
    # onto the top level (saturation), zero collapses both onto level 0.
    slot = np.searchsorted(levels, magnitude)
    upper = np.clip(slot, 0, top)
    lower = np.clip(slot - 1, 0, top)
    gap_below = np.abs(levels[lower] - magnitude)
    gap_above = np.abs(levels[upper] - magnitude)
    nearest = np.where(gap_below <= gap_above, lower, upper)
    codes = level_codes[nearest]

    # Set the sign bit only on non-zero results: sign + zero magnitude is NaN.
    negative = (values < 0) & (levels[nearest] != 0)
    codes = np.where(negative, codes | sign_bit, codes)
    codes = np.where(np.isnan(values), sign_bit, codes)  # NaN in -> NaN byte

    # Re-impose the operand's memory order on the final byte array.
    codes = codes.astype(np.uint8)
    if values.flags["F_CONTIGUOUS"]:
        return np.asfortranarray(codes)
    return np.ascontiguousarray(codes)
472+
473+
def _fp32_to_fp8_u8(x: np.ndarray) -> np.ndarray:
    """Return *x* encoded as fp8 E4M3 (FNUZ) bit patterns in a uint8 array."""
    # E4M3: 4 exponent bits, 3 mantissa bits.
    return _fnuz_encode(x, 4, 3)
477+
478+
def _fp8_u8_to_fp32(u8: np.ndarray) -> np.ndarray:
    """Return the fp32 values represented by fp8 E4M3 (FNUZ) bit patterns."""
    lut = _fnuz_decode_table(exp_bits=4, mant_bits=3)
    # Widen to an index type so each byte selects its table entry.
    index = u8.astype(np.intp)
    return lut[index]
482+
483+
def _fp32_to_bf8_u8(x: np.ndarray) -> np.ndarray:
    """Return *x* encoded as bf8 E5M2 (FNUZ) bit patterns in a uint8 array."""
    # E5M2: 5 exponent bits, 2 mantissa bits.
    return _fnuz_encode(x, 5, 2)
487+
488+
def _bf8_u8_to_fp32(u8: np.ndarray) -> np.ndarray:
    """Return the fp32 values represented by bf8 E5M2 (FNUZ) bit patterns."""
    lut = _fnuz_decode_table(exp_bits=5, mant_bits=2)
    # Widen to an index type so each byte selects its table entry.
    index = u8.astype(np.intp)
    return lut[index]
492+
493+
494+ # Output (C) element dtype for an A/B element dtype, mirroring the codegen's
495+ # CommonTypeMappings.get_output_dtype: fp8/bf8 accumulate into fp16, int8 into
496+ # int32, everything else stores in its own dtype.
497+ _OUTPUT_DTYPE = {"fp8" : "fp16" , "bf8" : "fp16" , "int8" : "int32" }
498+
499+
500+ def _output_dtype (dtype : str ) -> str :
501+ return _OUTPUT_DTYPE .get (dtype , dtype )
502+
503+
504+ # numpy carrier dtype for each output (C) element type. fp8/bf8 -> fp16 store,
505+ # int8 -> int32 accumulate, bf16 carried as raw uint16 bits.
506+ _C_NP = {"fp16" : np .float16 , "bf16" : np .uint16 , "int32" : np .int32 }
507+
508+
509+ def _detect_gpu_arch () -> Optional [str ]:
510+ """Best-effort detection of the active GPU's gcnArchName (e.g. 'gfx942').
511+
512+ Parses ``rocminfo`` for the first ``gfx*`` Name line. Returns ``None`` if it
513+ cannot be determined; callers treat that as "cannot verify arch" rather than
514+ a hard failure for non-fp8 dtypes.
515+ """
516+ import re
517+ import subprocess
518+
519+ try :
520+ out = subprocess .run (
521+ ["rocminfo" ], capture_output = True , text = True , timeout = 30
522+ ).stdout
523+ except Exception :
524+ return None
525+ m = re .search (r"^\s*Name:\s*(gfx[0-9a-fA-F]+)\s*$" , out , re .MULTILINE )
526+ return m .group (1 ) if m else None
527+
528+
394529class GpuGemmRunner :
395530 """High-level runner: construct from a .so path, call run(A, B, problem).
396531
@@ -434,32 +569,69 @@ def _bf16_encode(x: np.ndarray) -> np.ndarray:
434569 def _bf16_decode (u16 : np .ndarray ) -> np .ndarray :
435570 return (u16 .astype (np .uint32 ) << 16 ).view (np .float32 )
436571
# Static-method aliases for the module-level fp8/bf8 codecs (bit-exact to the
# device fp8_t/bf8_t, FNUZ on gfx942). They let references (smoke test,
# run_one) build decode(encode(x)) quantized inputs through the class instead
# of reaching into module functions.
_fp8_encode = staticmethod(_fp32_to_fp8_u8)
_bf8_encode = staticmethod(_fp32_to_bf8_u8)
_fp8_decode = staticmethod(_fp8_u8_to_fp32)
_bf8_decode = staticmethod(_bf8_u8_to_fp32)
579+
580+ def _check_arch_for_dtype (self ) -> None :
581+ """fp8/bf8 use the gfx942 FNUZ format. gfx950/MI350 uses OCP fp8, a
582+ different bit layout, so refuse rather than silently mis-decode."""
583+ if self ._dtype not in ("fp8" , "bf8" ):
584+ return
585+ arch = _detect_gpu_arch ()
586+ if arch is not None and arch != "gfx942" :
587+ raise RuntimeError (
588+ f"fp8/bf8 bridge codec is FNUZ (gfx942/MI300) only; detected "
589+ f"GPU arch { arch !r} . gfx950/MI350 uses OCP fp8 (different bit "
590+ f"layout) -- an OCP codec is required for that arch."
591+ )
592+
437593 def _to_buf (self , X : np .ndarray , major : str ) -> np .ndarray :
438594 """Lay out an operand in the order its layout implies: RowMajor ->
439595 C-contiguous, ColumnMajor -> F-contiguous. The .so reads a flat buffer
440- with the matching stride, so the raw byte order is what matters."""
596+ with the matching stride, so the raw byte order is what matters. The
597+ encode helpers (bf16/fp8/bf8) preserve that contiguity; int8/fp16 keep
598+ the requested order via astype(order='K')."""
441599 arr = np .ascontiguousarray (X ) if major == "r" else np .asfortranarray (X )
442600 if self ._dtype == "bf16" :
443601 return self ._bf16_encode (arr )
602+ if self ._dtype == "fp8" :
603+ return _fp32_to_fp8_u8 (arr )
604+ if self ._dtype == "bf8" :
605+ return _fp32_to_bf8_u8 (arr )
606+ if self ._dtype == "int8" :
607+ return arr .astype (np .int8 , order = "K" )
444608 return arr .astype (np .float16 , order = "K" )
445609
446610 def run (
447611 self , A : np .ndarray , B : np .ndarray , problem : GemmProblem
448612 ) -> GemmResult :
449613 M , N , K = problem .M , problem .N , problem .K
614+ self ._check_arch_for_dtype ()
450615
451- # Arrange A (MxK), B (KxN), C (MxN) per the kernel's actual layout. bf16 is
452- # passed as raw uint16 bits (the ctypes ABI is void*+sizeof, so 2-byte bf16
453- # and fp16 share the path; only the bit pattern differs) .
616+ # Arrange A (MxK), B (KxN), C (MxN) per the kernel's actual layout. The
617+ # ctypes ABI is void*+sizeof, so each dtype just needs the right bit
618+ # pattern: bf16 -> uint16, fp8/bf8 -> uint8, int8 -> int8, fp16 -> fp16 .
454619 la , lb , lc = self ._layout [0 ], self ._layout [1 ], self ._layout [2 ]
455620 A_h = self ._to_buf (A , la )
456621 B_h = self ._to_buf (B , lb )
457- cdt = np .uint16 if self ._dtype == "bf16" else np .float16
622+
623+ # The C buffer's element size must equal sizeof(CDataType): fp8/bf8
624+ # accumulate into fp16, int8 into int32, otherwise the input dtype (bf16
625+ # carried as raw uint16 bits).
626+ out_dtype = _output_dtype (self ._dtype )
627+ cdt = _C_NP .get (out_dtype , np .float16 )
458628 C_h = np .zeros ((M , N ), dtype = cdt , order = ("C" if lc == "r" else "F" ))
459629
460630 status , time_ms = self .lib .run (A_h , B_h , C_h , M , N , K )
461631
462- out = self ._bf16_decode (C_h ) if self ._dtype == "bf16" else C_h
632+ # Decode the output to a comparable numeric array. fp16/fp8/bf8 store fp16
633+ # (already comparable); int8 stores int32; only bf16 needs bit-decode.
634+ out = self ._bf16_decode (C_h ) if out_dtype == "bf16" else C_h
463635 tflops = (problem .flops / (time_ms * 1e-3 )) / 1e12 if time_ms > 0 else 0.0
464636 return GemmResult (
465637 output = out ,
0 commit comments