2121 fp8_per_block_quant_triton ,
2222 fp8_per_token_group_quant_triton ,
2323)
24- from .utils import QuantType , _ensure_deep_gemm
24+ from .utils import QuantType , _ensure_deep_gemm , _ensure_sgl_kernel
2525
2626FP8_MAX = float (torch .finfo (torch .float8_e4m3fn ).max )
2727FP8_MIN = float (torch .finfo (torch .float8_e4m3fn ).min )
@@ -87,6 +87,41 @@ def fp8_per_token_group_quant(
8787 )
8888
8989
90+ def fp8_per_channel_quant (weight : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
91+ """
92+ Per-channel FP8 weight quantization (E4M3 format)
93+
94+ Args:
95+ weight: Original weight tensor with shape [out_features, in_features]
96+
97+ Returns:
98+ weight_quant: Quantized weight [out_features, in_features], dtype=float8_e4m3fn
99+ weight_scale: Scale factors [out_features, 1], dtype=float32
100+ """
101+ abs_max = torch .abs (weight ).amax (dim = 1 , keepdim = True ) # [out_features, 1]
102+
103+ weight_scale = abs_max / FP8_MAX
104+ weight_scale = torch .clamp (weight_scale , min = 1e-12 )
105+
106+ weight_scaled = (weight / weight_scale ).clamp (min = FP8_MIN , max = FP8_MAX )
107+ weight_quant = weight_scaled .to (torch .float8_e4m3fn )
108+
109+ return weight_quant , weight_scale .float ()
110+
111+
112+ def fp8_per_token_quant_sgl (x : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
113+ m , k = x .shape
114+ input_tensor_quant = torch .empty (
115+ (m , k ), dtype = torch .float8_e4m3fn , device = "cuda" , requires_grad = False
116+ )
117+ input_tensor_scale = torch .empty (
118+ (m , 1 ), dtype = torch .float32 , device = "cuda" , requires_grad = False
119+ )
120+ _sgl_kernel = _ensure_sgl_kernel ()
121+ _sgl_kernel .sgl_per_token_quant_fp8 (x , input_tensor_quant , input_tensor_scale )
122+ return input_tensor_quant , input_tensor_scale
123+
124+
90125# pure torch implementation of block-wise FP8 quantization on cpu
91126def fp8_per_block_quant_torch (
92127 x : torch .Tensor , block_size : int = 128
@@ -260,6 +295,35 @@ def fp8_weight_only_gemm(A, B, B_scale, bias, out_dtype):
260295 return output
261296
262297
298+ def fp8_gemm_sgl_token (A , A_scale , B , B_scale , out_dtype , bias ):
299+ """GEMM function for FP8 per-token-sgl quantization using sgl-kernel.
300+
301+ Args:
302+ A: Input activation tensor
303+ A_scale: Scale tensor for input activations
304+ B: Weight tensor
305+ B_scale: Scale tensor for weights
306+ out_dtype: Output data type.
307+ bias: Optional bias tensor
308+
309+ Returns:
310+ torch.Tensor: Result of the GEMM operation.
311+ """
312+ _sgl_kernel = _ensure_sgl_kernel ()
313+ shape = (A .shape [0 ], B .shape [0 ])
314+ output = torch .empty (shape , dtype = out_dtype , device = A .device , requires_grad = False )
315+ output = _sgl_kernel .fp8_scaled_mm (
316+ A ,
317+ B .t (),
318+ A_scale ,
319+ B_scale .float (),
320+ out_dtype ,
321+ bias = bias ,
322+ )
323+
324+ return output
325+
326+
263327def fp8_gemm (
264328 A : torch .Tensor ,
265329 A_scale : torch .Tensor ,
@@ -300,6 +364,9 @@ def fp8_gemm(
300364 if quant_type in (QuantType .FP8_PER_TENSOR , QuantType .FP8_PER_TOKEN ):
301365 # Use torch native fp8 GEMM for per-tensor and per-token fp8 quantization
302366 return fp8_gemm_torch_tensor_token (A , A_scale , B , B_scale , out_dtype , bias )
367+ elif quant_type == QuantType .FP8_PER_TOKEN_SGL :
368+ # Use sgl-kernel for per-token-sgl fp8 quantization
369+ return fp8_gemm_sgl_token (A , A_scale , B , B_scale , out_dtype , bias )
303370 elif quant_type == QuantType .FP8_PER_BLOCK :
304371 # Use deepgemm accelerated blockwise fp8 GEMM
305372 return fp8_gemm_deepgemm_block (
@@ -324,7 +391,8 @@ def fp8_gemm(
324391 f"\n native_fp8_support={ native_fp8_support } .\n "
325392 "Supported combinations:\n "
326393 " - native_fp8_support=True, "
327- "quant_type in [fp8-per-tensor, fp8-per-token, fp8-per-block]\n "
394+ "quant_type in [fp8-per-tensor, fp8-per-token,"
395+ " fp8-per-block, fp8-per-token-sgl]\n "
328396 " - native_fp8_support=False, "
329397 "quant_type in [fp8-per-tensor, fp8-per-block]"
330398 )
0 commit comments